
Commit 761f029

[Tests] Fix spatial transformer tests on GPU (open-mmlab#531)
1 parent c1796ef commit 761f029

File tree

1 file changed: +12 −4 lines

tests/test_layers_utils.py

@@ -240,7 +240,9 @@ def test_attention_block_default(self):
          assert attention_scores.shape == (1, 32, 64, 64)
          output_slice = attention_scores[0, -1, -3:, -3:]

-        expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
+        expected_slice = torch.tensor(
+            [-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device
+        )
          assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

@@ -264,7 +266,9 @@ def test_spatial_transformer_default(self):
          assert attention_scores.shape == (1, 32, 64, 64)
          output_slice = attention_scores[0, -1, -3:, -3:]

-        expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
+        expected_slice = torch.tensor(
+            [-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201], device=torch_device
+        )
          assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

      def test_spatial_transformer_context_dim(self):
@@ -287,7 +291,9 @@ def test_spatial_transformer_context_dim(self):
          assert attention_scores.shape == (1, 64, 64, 64)
          output_slice = attention_scores[0, -1, -3:, -3:]

-        expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
+        expected_slice = torch.tensor(
+            [-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471], device=torch_device
+        )
          assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

      def test_spatial_transformer_dropout(self):
@@ -313,5 +319,7 @@ def test_spatial_transformer_dropout(self):
          assert attention_scores.shape == (1, 32, 64, 64)
          output_slice = attention_scores[0, -1, -3:, -3:]

-        expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
+        expected_slice = torch.tensor(
+            [-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device
+        )
          assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
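
For context on why every hunk makes the same change: torch.allclose cannot compare tensors that live on different devices, so on a CUDA runner the GPU output of these attention and spatial-transformer blocks cannot be checked against an expected tensor constructed on the CPU. Below is a minimal sketch of the failure mode and the fix, assuming torch_device stands in for the test suite's device helper (not shown in this diff):

import torch

# Assumption: stand-in for the test utilities' `torch_device` helper.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

output = torch.randn(3, 3, device=torch_device)  # e.g. a model output on GPU

# Before the fix: torch.tensor(...) places the expected values on the CPU
# by default. On a CUDA machine, comparing it against `output` raises:
# RuntimeError: Expected all tensors to be on the same device ...
expected_cpu = torch.zeros(3, 3)

# After the fix: build the expected tensor on the same device as the output,
# so the comparison passes on both CPU-only and GPU runners.
expected = torch.zeros(3, 3, device=torch_device)
assert torch.allclose(output - output, expected, atol=1e-3)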
