diff --git a/tests/models/swiftformer/test_modeling_swiftformer.py b/tests/models/swiftformer/test_modeling_swiftformer.py index ca616976fd4a78..f2785a6640cae1 100644 --- a/tests/models/swiftformer/test_modeling_swiftformer.py +++ b/tests/models/swiftformer/test_modeling_swiftformer.py @@ -303,5 +303,5 @@ def test_inference_image_classification_head(self): expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([[-2.1703e00, 2.1107e00, -2.0811e00]]) + expected_slice = torch.tensor([[-2.1703e00, 2.1107e00, -2.0811e00]]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))