Skip to content

Commit 3b15705

Browse files
committed
parameterize nms gpu test
1 parent 136bc47 commit 3b15705

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

test/test_ops.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -750,47 +750,36 @@ def test_qnms(self, iou, scale, zero_point):
750750

751751
torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
752752

753-
@needs_cuda
753+
@pytest.mark.parametrize(
754+
"device",
755+
(
756+
pytest.param("cuda", marks=pytest.mark.needs_cuda),
757+
pytest.param("mps", marks=pytest.mark.needs_mps),
758+
),
759+
)
754760
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
755-
def test_nms_cuda(self, iou, dtype=torch.float64):
761+
def test_nms_gpu(self, iou, device, dtype=torch.float64):
762+
dtype = torch.float32 if device == "mps" else dtype
756763
tol = 1e-3 if dtype is torch.half else 1e-5
757764
err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
758765

759766
boxes, scores = self._create_tensors_with_iou(1000, iou)
760767
r_cpu = ops.nms(boxes, scores, iou)
761-
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
762-
763-
is_eq = torch.allclose(r_cpu, r_cuda.cpu())
764-
if not is_eq:
765-
# if the indices are not the same, ensure that it's because the scores
766-
# are duplicate
767-
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
768-
assert is_eq, err_msg.format(iou)
769-
770-
@needs_mps
771-
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
772-
def test_nms_mps(self, iou, dtype=torch.float32):
773-
tol = 1e-3 if dtype is torch.half else 1e-5
774-
err_msg = "NMS incompatible between CPU and MPS for IoU={}"
775-
776-
boxes, scores = self._create_tensors_with_iou(1000, iou)
777-
r_cpu = ops.nms(boxes, scores, iou)
778-
r_mps = ops.nms(boxes.to("mps"), scores.to("mps"), iou)
768+
r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)
779769

780-
print(r_cpu.size(), r_mps.size())
781-
is_eq = torch.allclose(r_cpu, r_mps.cpu())
770+
is_eq = torch.allclose(r_cpu, r_gpu.cpu())
782771
if not is_eq:
783772
# if the indices are not the same, ensure that it's because the scores
784773
# are duplicate
785-
is_eq = torch.allclose(scores[r_cpu], scores[r_mps.cpu()], rtol=tol, atol=tol)
774+
is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
786775
assert is_eq, err_msg.format(iou)
787776

788777
@needs_cuda
789778
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
790779
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
791780
def test_autocast(self, iou, dtype):
792781
with torch.cuda.amp.autocast():
793-
self.test_nms_cuda(iou=iou, dtype=dtype)
782+
self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
794783

795784
@pytest.mark.parametrize(
796785
"device",

0 commit comments

Comments (0)