@@ -240,7 +240,9 @@ def test_attention_block_default(self):
         assert attention_scores.shape == (1, 32, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427])
+        expected_slice = torch.tensor(
+            [-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device
+        )
 
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
@@ -264,7 +266,9 @@ def test_spatial_transformer_default(self):
         assert attention_scores.shape == (1, 32, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201])
+        expected_slice = torch.tensor(
+            [-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201], device=torch_device
+        )
 
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
@@ -287,7 +291,9 @@ def test_spatial_transformer_context_dim(self):
         assert attention_scores.shape == (1, 64, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471])
+        expected_slice = torch.tensor(
+            [-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471], device=torch_device
+        )
 
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
@@ -313,5 +319,7 @@ def test_spatial_transformer_dropout(self):
         assert attention_scores.shape == (1, 32, 64, 64)
         output_slice = attention_scores[0, -1, -3:, -3:]
 
-        expected_slice = torch.tensor([-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091])
+        expected_slice = torch.tensor(
+            [-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device
+        )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
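
The change repeated in all four hunks is the same: each hard-coded expected slice is now created directly on the test device rather than on the CPU, so the torch.allclose comparison against a possibly CUDA-resident model output no longer fails with a device mismatch. Below is a minimal sketch of that pattern, not the tests themselves: torch_device is assumed to resolve to the available accelerator (in the real tests it comes from the library's testing utilities), and the expected values are placeholders.

import torch

# Stand-in for the test suite's device helper; an assumption, not the actual import.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretend model output with the same shape the tests assert on.
attention_scores = torch.randn(1, 32, 64, 64, device=torch_device)
output_slice = attention_scores[0, -1, -3:, -3:]

# Placeholder expected values created on the same device as the output,
# mirroring the device=torch_device argument added in this diff.
expected_slice = torch.zeros(9, device=torch_device)

# torch.allclose requires both tensors to live on the same device; with the
# expected slice on torch_device the comparison runs on CPU and GPU alike.
print(torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3))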