Skip to content

Commit 084eaff

Browse files
committed
Support QAT int4 v1 path for BC
**Summary:** `Int4WeightOnlyConfig` supports version 1 (targeting tinygemm) and version 2 (targeting fbgemm). However, version 2 requires a new dependency (fbgemm_gpu_genai >= 1.2.0), which is problematic for torchao integrations with other frameworks. For now, we should continue to support the v1 path for backward compatibility (BC).

**Test Plan:**
```
python test/quantization/test_qat.py -k test_infer_int4_weight_only_config
```
1 parent 6f035e8 commit 084eaff

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

test/quantization/test_qat.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
from torchao.quantization.quant_api import (
7070
Float8DynamicActivationFloat8WeightConfig,
7171
Float8DynamicActivationInt4WeightConfig,
72+
Int4WeightOnlyConfig,
7273
Int8DynamicActivationInt4WeightConfig,
7374
)
7475
from torchao.quantization.quant_primitives import (
@@ -1932,7 +1933,6 @@ def test_quantize_api_fp8_int4(self):
19321933
target_convert_sqnr=float("inf"),
19331934
)
19341935

1935-
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
19361936
def test_infer_fp8_int4_config(self):
19371937
"""
19381938
Test that fake quantize configs are correctly inferred from
@@ -1952,6 +1952,29 @@ def test_infer_fp8_int4_config(self):
19521952
self.assertEqual(weight_config.group_size, 128)
19531953
self.assertTrue(weight_config.is_symmetric)
19541954

1955+
def test_infer_int4_weight_only_config(self):
1956+
"""
1957+
Test that fake quantize configs are correctly inferred from `Int4WeightOnlyConfig`.
1958+
"""
1959+
from torchao.quantization.qat.fake_quantize_config import (
1960+
_infer_fake_quantize_configs,
1961+
)
1962+
1963+
base_config = Int4WeightOnlyConfig(version=1)
1964+
(act_config, weight_config) = _infer_fake_quantize_configs(base_config)
1965+
self.assertIsNone(act_config)
1966+
self.assertIsInstance(weight_config, IntxFakeQuantizeConfig)
1967+
self.assertEqual(weight_config.dtype, torch.uint4)
1968+
self.assertEqual(weight_config.group_size, 128)
1969+
self.assertFalse(weight_config.is_symmetric)
1970+
1971+
base_config = Int4WeightOnlyConfig(version=2)
1972+
(act_config, weight_config) = _infer_fake_quantize_configs(base_config)
1973+
self.assertIsNone(act_config)
1974+
self.assertEqual(weight_config.dtype, torch.int4)
1975+
self.assertEqual(weight_config.group_size, 128)
1976+
self.assertTrue(weight_config.is_symmetric)
1977+
19551978
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
19561979
def test_quantize_api_nvfp4(self):
19571980
"""

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,23 @@ def _infer_fake_quantize_configs(
358358
is_symmetric=base_config.mapping_type == MappingType.SYMMETRIC,
359359
)
360360
elif isinstance(base_config, Int4WeightOnlyConfig):
361-
if base_config.version != 2:
362-
raise ValueError(f"Only version 2 of {type(base_config)} is supported")
363361
act_config = None
364-
weight_config = IntxFakeQuantizeConfig(
365-
dtype=torch.int4,
366-
group_size=base_config.group_size,
367-
is_symmetric=True,
368-
)
362+
if base_config.version == 2:
363+
weight_config = IntxFakeQuantizeConfig(
364+
dtype=torch.int4,
365+
group_size=base_config.group_size,
366+
is_symmetric=True,
367+
)
368+
elif base_config.version == 1:
369+
# For BC
370+
weight_config = IntxFakeQuantizeConfig(
371+
dtype=torch.uint4,
372+
group_size=base_config.group_size,
373+
is_symmetric=False,
374+
zero_point_domain=base_config.zero_point_domain,
375+
)
376+
else:
377+
raise ValueError(f"Unknown version on base config {type(base_config)}")
369378
elif isinstance(base_config, Float8DynamicActivationFloat8WeightConfig):
370379
if base_config.version != 2:
371380
raise ValueError(f"Only version 2 of {type(base_config)} is supported")

0 commit comments

Comments
 (0)