Skip to content

Fix inconsistent NMS implementation between CPU and CUDA #1556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,26 +1196,33 @@ def reference_nms(self, boxes, scores, iou_threshold):

return torch.as_tensor(picked)

def _create_tensors(self, N):
def _create_tensors_with_iou(self, N, iou_thresh):
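# Force the last box to have a pre-defined IoU with the first box.
# Let b0 be [x0, y0, x1, y1] and b1 be [x0, y0, x1 + d, y1] with d > 0.
# b0 lies entirely inside b1, so their intersection is b0 and their union is b1:
#     IoU(b0, b1) = area(b0) / area(b1) = (x1 - x0) / (x1 - x0 + d)
# Setting this equal to iou_thresh and solving for d gives
#     d = (x1 - x0) * (1 - iou_thresh) / iou_thresh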
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += torch.rand(N, 2) * 100
boxes[:, 2:] += boxes[:, :2]
boxes[-1, :] = boxes[0, :]
x0, y0, x1, y1 = boxes[-1].tolist()
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
scores = torch.rand(N)
return boxes, scores

def test_nms(self):
boxes, scores = self._create_tensors(1000)
err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}'
for iou in [0.2, 0.5, 0.8]:
boxes, scores = self._create_tensors_with_iou(1000, iou)
keep_ref = self.reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou)
self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda(self):
boxes, scores = self._create_tensors(1000)
err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'

for iou in [0.2, 0.5, 0.8]:
boxes, scores = self._create_tensors_with_iou(1000, iou)
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)

Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/nms_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ at::Tensor nms_cpu_kernel(
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
auto inter = w * h;
auto ovr = inter / (iarea + areas[j] - inter);
if (ovr >= iou_threshold)
if (ovr > iou_threshold)
suppressed[j] = 1;
}
}
Expand Down
1 change: 0 additions & 1 deletion torchvision/csrc/cuda/nms_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ __global__ void nms_kernel(
at::Tensor nms_cuda(const at::Tensor& dets,
const at::Tensor& scores,
float iou_threshold) {
using scalar_t = float;
AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
Expand Down