@@ -102,20 +102,22 @@ class RoIOpTester(ABC):
     @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize(
-        "dtype",
+        "x_dtype",
         (
             torch.float16,
             torch.float32,
             torch.float64,
         ),
         ids=str,
     )
-    def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs):
-        if device == "mps" and dtype is torch.float64:
+    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
+        if device == "mps" and x_dtype is torch.float64:
             pytest.skip("MPS does not support float64")
 
+        rois_dtype = x_dtype if rois_dtype is None else rois_dtype
+
         tol = 1e-5
-        if dtype is torch.half:
+        if x_dtype is torch.half:
             if device == "mps":
                 tol = 5e-3
             else:
@@ -124,12 +126,12 @@ def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs)
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
         n_channels = 2 * (pool_size**2)
-        x = torch.rand(2, n_channels, 10, 10, dtype=dtype, device=device)
+        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
         if not contiguous:
             x = x.permute(0, 1, 3, 2)
         rois = torch.tensor(
             [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
-            dtype=dtype,
+            dtype=rois_dtype,
             device=device,
         )
 
@@ -139,7 +141,7 @@ def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs)
         # the following should be true whether we're running an autocast test or not.
         assert y.dtype == x.dtype
         gt_y = self.expected_fn(
-            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=dtype, **kwargs
+            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
         )
 
         torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
@@ -460,17 +462,17 @@ def test_boxes_shape(self):
 
     @pytest.mark.parametrize("aligned", (True, False))
     @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
-    @pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64), ids=str)
+    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
-    def test_forward(self, device, contiguous, deterministic, aligned, dtype):
+    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
         super().test_forward(
             device=device,
             contiguous=contiguous,
             deterministic=deterministic,
-            dtype=dtype,
+            x_dtype=x_dtype,
             aligned=aligned,
         )
 
0 commit comments