Skip to content

Commit 136bc47

Browse files
committed
Fix dtype parameters
1 parent 66a00fc commit 136bc47

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

test/test_ops.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,22 @@ class RoIOpTester(ABC):
102102
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
103103
@pytest.mark.parametrize("contiguous", (True, False))
104104
@pytest.mark.parametrize(
105-
"dtype",
105+
"x_dtype",
106106
(
107107
torch.float16,
108108
torch.float32,
109109
torch.float64,
110110
),
111111
ids=str,
112112
)
113-
def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs):
114-
if device == "mps" and dtype is torch.float64:
113+
def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
114+
if device == "mps" and x_dtype is torch.float64:
115115
pytest.skip("MPS does not support float64")
116116

117+
rois_dtype = x_dtype if rois_dtype is None else rois_dtype
118+
117119
tol = 1e-5
118-
if dtype is torch.half:
120+
if x_dtype is torch.half:
119121
if device == "mps":
120122
tol = 5e-3
121123
else:
@@ -124,12 +126,12 @@ def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs)
124126
pool_size = 5
125127
# n_channels % (pool_size ** 2) == 0 required for PS operations.
126128
n_channels = 2 * (pool_size**2)
127-
x = torch.rand(2, n_channels, 10, 10, dtype=dtype, device=device)
129+
x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
128130
if not contiguous:
129131
x = x.permute(0, 1, 3, 2)
130132
rois = torch.tensor(
131133
[[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy)
132-
dtype=dtype,
134+
dtype=rois_dtype,
133135
device=device,
134136
)
135137

@@ -139,7 +141,7 @@ def test_forward(self, device, contiguous, dtype, deterministic=False, **kwargs)
139141
# the following should be true whether we're running an autocast test or not.
140142
assert y.dtype == x.dtype
141143
gt_y = self.expected_fn(
142-
x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=dtype, **kwargs
144+
x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
143145
)
144146

145147
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
@@ -460,17 +462,17 @@ def test_boxes_shape(self):
460462

461463
@pytest.mark.parametrize("aligned", (True, False))
462464
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
463-
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64), ids=str)
465+
@pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
464466
@pytest.mark.parametrize("contiguous", (True, False))
465467
@pytest.mark.parametrize("deterministic", (True, False))
466-
def test_forward(self, device, contiguous, deterministic, aligned, dtype):
468+
def test_forward(self, device, contiguous, deterministic, aligned, x_dtype):
467469
if deterministic and device == "cpu":
468470
pytest.skip("cpu is always deterministic, don't retest")
469471
super().test_forward(
470472
device=device,
471473
contiguous=contiguous,
472474
deterministic=deterministic,
473-
dtype=dtype,
475+
x_dtype=x_dtype,
474476
aligned=aligned,
475477
)
476478

0 commit comments

Comments
 (0)