diff --git a/test/test_ops.py b/test/test_ops.py index e59df26702f..f4440869896 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1196,26 +1196,33 @@ def reference_nms(self, boxes, scores, iou_threshold): return torch.as_tensor(picked) - def _create_tensors(self, N): + def _create_tensors_with_iou(self, N, iou_thresh): + # force last box to have a pre-defined iou with the first box + # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1], + # then, in order to satisfy ops.iou(b0, b1) == iou_thresh, + # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh boxes = torch.rand(N, 4) * 100 - boxes[:, 2:] += torch.rand(N, 2) * 100 + boxes[:, 2:] += boxes[:, :2] + boxes[-1, :] = boxes[0, :] + x0, y0, x1, y1 = boxes[-1].tolist() + boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh scores = torch.rand(N) return boxes, scores def test_nms(self): - boxes, scores = self._create_tensors(1000) err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}' for iou in [0.2, 0.5, 0.8]: + boxes, scores = self._create_tensors_with_iou(1000, iou) keep_ref = self.reference_nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou) self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou)) @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_nms_cuda(self): - boxes, scores = self._create_tensors(1000) err_msg = 'NMS incompatible between CPU and CUDA for IoU={}' for iou in [0.2, 0.5, 0.8]: + boxes, scores = self._create_tensors_with_iou(1000, iou) r_cpu = ops.nms(boxes, scores, iou) r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou) diff --git a/torchvision/csrc/cpu/nms_cpu.cpp b/torchvision/csrc/cpu/nms_cpu.cpp index 962229119de..47b771fa943 100644 --- a/torchvision/csrc/cpu/nms_cpu.cpp +++ b/torchvision/csrc/cpu/nms_cpu.cpp @@ -61,7 +61,7 @@ at::Tensor nms_cpu_kernel( auto h = std::max(static_cast(0), yy2 - yy1); auto inter = w * h; auto ovr = inter / (iarea + areas[j] - inter); - if (ovr >= iou_threshold) + if (ovr > iou_threshold) suppressed[j] = 1; } } diff --git a/torchvision/csrc/cuda/nms_cuda.cu b/torchvision/csrc/cuda/nms_cuda.cu index 4ffda3f58d0..c5f6edb1b90 100644 --- a/torchvision/csrc/cuda/nms_cuda.cu +++ b/torchvision/csrc/cuda/nms_cuda.cu @@ -72,7 +72,6 @@ __global__ void nms_kernel( at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, float iou_threshold) { - using scalar_t = float; AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor"); AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor"); at::cuda::CUDAGuard device_guard(dets.device());