Skip to content

Commit b1cf619

Browse files
committed
parameterize nms gpu test
1 parent 70f3906 commit b1cf619

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

test/test_ops.py

Lines changed: 13 additions & 24 deletions
Original file line number · Diff line number · Diff line change
@@ -751,47 +751,36 @@ def test_qnms(self, iou, scale, zero_point):
751751

752752
torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
753753

754-
@needs_cuda
754+
@pytest.mark.parametrize(
755+
"device",
756+
(
757+
pytest.param("cuda", marks=pytest.mark.needs_cuda),
758+
pytest.param("mps", marks=pytest.mark.needs_mps),
759+
),
760+
)
755761
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
756-
def test_nms_cuda(self, iou, dtype=torch.float64):
762+
def test_nms_gpu(self, iou, device, dtype=torch.float64):
763+
dtype = torch.float32 if device == "mps" else dtype
757764
tol = 1e-3 if dtype is torch.half else 1e-5
758765
err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
759766

760767
boxes, scores = self._create_tensors_with_iou(1000, iou)
761768
r_cpu = ops.nms(boxes, scores, iou)
762-
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
763-
764-
is_eq = torch.allclose(r_cpu, r_cuda.cpu())
765-
if not is_eq:
766-
# if the indices are not the same, ensure that it's because the scores
767-
# are duplicate
768-
is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
769-
assert is_eq, err_msg.format(iou)
770-
771-
@needs_mps
772-
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
773-
def test_nms_mps(self, iou, dtype=torch.float32):
774-
tol = 1e-3 if dtype is torch.half else 1e-5
775-
err_msg = "NMS incompatible between CPU and MPS for IoU={}"
776-
777-
boxes, scores = self._create_tensors_with_iou(1000, iou)
778-
r_cpu = ops.nms(boxes, scores, iou)
779-
r_mps = ops.nms(boxes.to("mps"), scores.to("mps"), iou)
769+
r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)
780770

781-
print(r_cpu.size(), r_mps.size())
782-
is_eq = torch.allclose(r_cpu, r_mps.cpu())
771+
is_eq = torch.allclose(r_cpu, r_gpu.cpu())
783772
if not is_eq:
784773
# if the indices are not the same, ensure that it's because the scores
785774
# are duplicate
786-
is_eq = torch.allclose(scores[r_cpu], scores[r_mps.cpu()], rtol=tol, atol=tol)
775+
is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
787776
assert is_eq, err_msg.format(iou)
788777

789778
@needs_cuda
790779
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
791780
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
792781
def test_autocast(self, iou, dtype):
793782
with torch.cuda.amp.autocast():
794-
self.test_nms_cuda(iou=iou, dtype=dtype)
783+
self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
795784

796785
@pytest.mark.parametrize(
797786
"device",

0 commit comments

Comments (0)