diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index fa25b5b7ab81..8809bac25f58 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -221,7 +221,7 @@ def test_keep_modules_in_fp32(self):
                     self.assertTrue(module.weight.dtype == torch.int8)
 
         # test if inference works.
-        with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.no_grad() and torch.autocast(model.device.type, dtype=torch.float16):
             input_dict_for_transformer = self.get_dummy_inputs()
             model_inputs = {
                 k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
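
For context, here is a minimal standalone sketch of the device-agnostic pattern the diff adopts: deriving the autocast device-type string from the model itself instead of hard-coding "cuda". The `nn.Linear` model, dtype selection, and variable names are illustrative assumptions, not taken from the test file.

import torch
from torch import nn

# Plain nn.Module has no `.device` property (diffusers' ModelMixin does),
# so here we read the device from the parameters instead.
model = nn.Linear(8, 8)
device_type = next(model.parameters()).device.type  # "cpu" here, "cuda" on GPU

# Illustrative dtype choice: CPU autocast is most reliably supported with
# bfloat16, while float16 is the usual choice on CUDA.
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

# Note: two context managers joined by a comma. The `A and B` form kept in
# the diff evaluates to only the second manager, since `torch.no_grad()` is
# a truthy object, so `no_grad` would never actually be entered.
with torch.no_grad(), torch.autocast(device_type, dtype=amp_dtype):
    out = model(torch.randn(2, 8))

print(out.dtype)  # low-precision dtype inside the autocast region

Passing `model.device.type` (or its equivalent) lets the same test run unchanged on CPU-only machines and on other accelerator backends, which is the point of moving away from the hard-coded "cuda" string.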