@@ -750,47 +750,36 @@ def test_qnms(self, iou, scale, zero_point):
750750
751751 torch .testing .assert_close (qkeep , keep , msg = err_msg .format (iou ))
752752
753- @needs_cuda
753+ @pytest .mark .parametrize (
754+ "device" ,
755+ (
756+ pytest .param ("cuda" , marks = pytest .mark .needs_cuda ),
757+ pytest .param ("mps" , marks = pytest .mark .needs_mps ),
758+ ),
759+ )
754760 @pytest .mark .parametrize ("iou" , (0.2 , 0.5 , 0.8 ))
755- def test_nms_cuda (self , iou , dtype = torch .float64 ):
761+ def test_nms_gpu (self , iou , device , dtype = torch .float64 ):
762+ dtype = torch .float32 if device == "mps" else dtype
756763 tol = 1e-3 if dtype is torch .half else 1e-5
757764 err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
758765
759766 boxes , scores = self ._create_tensors_with_iou (1000 , iou )
760767 r_cpu = ops .nms (boxes , scores , iou )
761- r_cuda = ops .nms (boxes .cuda (), scores .cuda (), iou )
762-
763- is_eq = torch .allclose (r_cpu , r_cuda .cpu ())
764- if not is_eq :
765- # if the indices are not the same, ensure that it's because the scores
766- # are duplicate
767- is_eq = torch .allclose (scores [r_cpu ], scores [r_cuda .cpu ()], rtol = tol , atol = tol )
768- assert is_eq , err_msg .format (iou )
769-
770- @needs_mps
771- @pytest .mark .parametrize ("iou" , (0.2 , 0.5 , 0.8 ))
772- def test_nms_mps (self , iou , dtype = torch .float32 ):
773- tol = 1e-3 if dtype is torch .half else 1e-5
774- err_msg = "NMS incompatible between CPU and MPS for IoU={}"
775-
776- boxes , scores = self ._create_tensors_with_iou (1000 , iou )
777- r_cpu = ops .nms (boxes , scores , iou )
778- r_mps = ops .nms (boxes .to ("mps" ), scores .to ("mps" ), iou )
768+ r_gpu = ops .nms (boxes .to (device ), scores .to (device ), iou )
779769
780- print (r_cpu .size (), r_mps .size ())
781- is_eq = torch .allclose (r_cpu , r_mps .cpu ())
770+ is_eq = torch .allclose (r_cpu , r_gpu .cpu ())
782771 if not is_eq :
783772 # if the indices are not the same, ensure that it's because the scores
784773 # are duplicate
785- is_eq = torch .allclose (scores [r_cpu ], scores [r_mps .cpu ()], rtol = tol , atol = tol )
774+ is_eq = torch .allclose (scores [r_cpu ], scores [r_gpu .cpu ()], rtol = tol , atol = tol )
786775 assert is_eq , err_msg .format (iou )
787776
@needs_cuda
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
def test_autocast(self, iou, dtype):
    """Re-run the CUDA NMS consistency check under ``torch.cuda.amp.autocast``."""
    with torch.cuda.amp.autocast():
        self.test_nms_gpu(device="cuda", iou=iou, dtype=dtype)
794783
795784 @pytest .mark .parametrize (
796785 "device" ,
0 commit comments