@@ -751,47 +751,36 @@ def test_qnms(self, iou, scale, zero_point):
 
         torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
 
-    @needs_cuda
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
-    def test_nms_cuda(self, iou, dtype=torch.float64):
+    def test_nms_gpu(self, iou, device, dtype=torch.float64):
+        dtype = torch.float32 if device == "mps" else dtype
         tol = 1e-3 if dtype is torch.half else 1e-5
         err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
 
         boxes, scores = self._create_tensors_with_iou(1000, iou)
         r_cpu = ops.nms(boxes, scores, iou)
-        r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
-
-        is_eq = torch.allclose(r_cpu, r_cuda.cpu())
-        if not is_eq:
-            # if the indices are not the same, ensure that it's because the scores
-            # are duplicate
-            is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
-        assert is_eq, err_msg.format(iou)
-
-    @needs_mps
-    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
-    def test_nms_mps(self, iou, dtype=torch.float32):
-        tol = 1e-3 if dtype is torch.half else 1e-5
-        err_msg = "NMS incompatible between CPU and MPS for IoU={}"
-
-        boxes, scores = self._create_tensors_with_iou(1000, iou)
-        r_cpu = ops.nms(boxes, scores, iou)
-        r_mps = ops.nms(boxes.to("mps"), scores.to("mps"), iou)
+        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)
 
-        print(r_cpu.size(), r_mps.size())
-        is_eq = torch.allclose(r_cpu, r_mps.cpu())
+        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
         if not is_eq:
             # if the indices are not the same, ensure that it's because the scores
             # are duplicate
-            is_eq = torch.allclose(scores[r_cpu], scores[r_mps.cpu()], rtol=tol, atol=tol)
+            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)
 
     @needs_cuda
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, iou, dtype):
         with torch.cuda.amp.autocast():
-            self.test_nms_cuda(iou=iou, dtype=dtype)
+            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
 
     @pytest.mark.parametrize(
         "device",