diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 2eca594e4..fa66867f3 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -30,6 +30,7 @@ TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_5, + TORCH_VERSION_AT_LEAST_2_6, is_fbcode, ) @@ -201,6 +202,7 @@ def test_choose_qparams_group_sym_no_clipping_err(self): self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower") + @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or higher") @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) diff --git a/torchao/utils.py b/torchao/utils.py index f1248f67a..c0b79fa71 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -26,6 +26,7 @@ "TORCH_VERSION_AT_LEAST_2_3", "TORCH_VERSION_AT_LEAST_2_4", "TORCH_VERSION_AT_LEAST_2_5", + "TORCH_VERSION_AT_LEAST_2_6", # Needs to be deprecated in the future "TORCH_VERSION_AFTER_2_2", @@ -317,6 +318,7 @@ def is_fbcode(): def torch_version_at_least(min_version): return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0 +TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0") TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0") TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0") TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0")