diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index f523cb091c..8e2876467b 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -213,6 +213,11 @@ def forward(self, x):
 
 class TestQAT(TestCase):
     SEED = 123
+    DEVICE = (
+        torch.accelerator.current_accelerator()
+        if torch.accelerator.is_available()
+        else None
+    )
 
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
@@ -347,7 +352,7 @@ def _set_ptq_weight(
                 group_size,
             )
             q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                q_weight.to("cuda"),
+                q_weight.to(self.DEVICE),
                 qat_linear.inner_k_tiles,
             )
             ptq_linear.weight = q_weight
@@ -600,13 +605,13 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
         inner_k_tiles = 8
         scales_precision = torch.bfloat16
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -651,13 +656,13 @@ def test_qat_4w_primitives(self):
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -692,15 +697,19 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
+    @unittest.skipIf(
+        DEVICE is not None and DEVICE.type == "xpu",
+        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+    )
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
 
         group_size = 32
         inner_k_tiles = 8
-        device = torch.device("cuda")
         dtype = torch.bfloat16
+        device = self.DEVICE
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
@@ -709,8 +718,7 @@ def test_qat_4w_quantizer(self):
             inner_k_tiles=inner_k_tiles,
         )
         ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
+            groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1891,12 +1899,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)
 
         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1971,7 +1979,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
@@ -1984,7 +1992,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
        )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2009,7 +2017,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
        [