
Commit f48e988

Rename Int4WeightPreshuffledFakeQuantizeConfig
**Summary:** This config actually works for both preshuffled and plain int4 QAT, so we remove "Preshuffled" from the name.

BC-breaking notes:

```
Int4WeightPreshuffledFakeQuantizeConfig -> Int4WeightFakeQuantizeConfig
Int4WeightPreshuffledFakeQuantizer -> Int4WeightFakeQuantizer
```

**Test Plan:**

```
python test/quantization/test_qat.py
```
1 parent ea8c00f commit f48e988
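For downstream code, the rename is mechanical. A minimal before/after sketch of the import change, using the module paths that appear in the diffs below (illustrative only, not necessarily the whole affected import surface):

```python
# Before this commit:
# from torchao.quantization.qat.fake_quantize_config import (
#     Int4WeightPreshuffledFakeQuantizeConfig,
# )
# from torchao.quantization.qat.fake_quantizer import (
#     Int4WeightPreshuffledFakeQuantizer,
# )

# After this commit, the same classes under their new names:
from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    Int4WeightFakeQuantizer,
)
```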

3 files changed: +14 −16 lines changed

test/quantization/test_qat.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -50,7 +50,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
-    Int4WeightPreshuffledFakeQuantizeConfig,
+    Int4WeightFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -1985,7 +1985,7 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, e4m3_dtype)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
+        self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
         self.assertEqual(weight_config.activation_dtype, e4m3_dtype)
@@ -2008,7 +2008,7 @@ def test_infer_int4_weight_only_config(self):
         base_config = Int4WeightOnlyConfig(version=2)
         (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         self.assertIsNone(act_config)
-        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
+        self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
         self.assertEqual(weight_config.activation_dtype, torch.bfloat16)
@@ -2102,7 +2102,7 @@ def test_fbgemm_fp8_int4_preshuffled_primitives(self):
         """
         Compare numerics between:
         (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
-        (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        (2) Our reference QAT version in `Int4WeightFakeQuantizer`
         """
         from fbgemm_gpu.experimental.gen_ai.quantize import (
             int4_row_quantize,
@@ -2184,7 +2184,7 @@ def test_fbgemm_int4_weight_only_primitives(self):
         """
         Compare numerics between:
         (1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp
-        (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        (2) Our reference QAT version in `Int4WeightFakeQuantizer`
         """
         from fbgemm_gpu.experimental.gen_ai.quantize import (
             int4_row_quantize_zp,
```
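The two `test_infer_*` hunks above pin down what config inference returns for the renamed class. A condensed sketch of the weight-only case, with names and values taken from the diff (`_infer_fake_quantize_configs` is a private helper in `fake_quantize_config.py`, so treat this as illustrative rather than a stable API):

```python
import torch

from torchao.quantization import Int4WeightOnlyConfig
from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
    _infer_fake_quantize_configs,
)

# Weight-only int4 (version=2): no activation fake quantization,
# int4 weights fake-quantized against bf16 activations.
base_config = Int4WeightOnlyConfig(version=2)
act_config, weight_config = _infer_fake_quantize_configs(base_config)
assert act_config is None
assert isinstance(weight_config, Int4WeightFakeQuantizeConfig)
assert weight_config.group_size == 128
assert weight_config.activation_dtype == torch.bfloat16
```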

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -78,9 +78,8 @@ def __post_init__(self):
             )


-# TODO: rename this config, it actually works for both plain and preshuffled
 @dataclass
-class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
+class Int4WeightFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
     Config for pint4 weight fake quantization that targets the numerics in the following preshuffled kernel:
         torch.ops.fbgemm.f8i4bf16_shuffled
@@ -393,7 +392,7 @@ def _infer_fake_quantize_configs(
             raise ValueError(
                 f"Packing format must be one of {supported_packing_formats}"
             )
-        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+        weight_config = Int4WeightFakeQuantizeConfig(
             group_size=128,
             activation_dtype=torch.bfloat16,
         )
@@ -436,7 +435,7 @@ def _infer_fake_quantize_configs(
             dtype=e4m3_dtype,
             granularity=PerRow(),
         )
-        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+        weight_config = Int4WeightFakeQuantizeConfig(
             group_size=128,
             activation_dtype=e4m3_dtype,
         )
```
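Both inference branches in this file now construct the renamed config; the only difference between the plain int4 path and the fp8-int4 path is the activation dtype. A small sketch, assuming `e4m3_dtype` resolves to `torch.float8_e4m3fn` as it does on CUDA builds of torchao:

```python
import torch

from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
)

# Plain int4 path (bf16 activations), as built in the first branch above:
bf16_act_config = Int4WeightFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.bfloat16,
)

# fp8-int4 path (e4m3 activations), as built in the second branch above:
fp8_act_config = Int4WeightFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.float8_e4m3fn,
)
```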

torchao/quantization/qat/fake_quantizer.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -35,7 +35,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
-    Int4WeightPreshuffledFakeQuantizeConfig,
+    Int4WeightFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .utils import (
@@ -68,8 +68,8 @@ def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":

         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
-        elif isinstance(config, Int4WeightPreshuffledFakeQuantizeConfig):
-            return Int4WeightPreshuffledFakeQuantizer(config)
+        elif isinstance(config, Int4WeightFakeQuantizeConfig):
+            return Int4WeightFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
         elif isinstance(config, NVFP4FakeQuantizeConfig):
@@ -103,8 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq


-# TODO: rename this, it also works for plain Int4Tensor
-class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
+class Int4WeightFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying int4 fake quantization to a weight tensor,
     targeting the following FBGEMM kernels:
@@ -113,11 +112,11 @@ class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
         torch.ops.fbgemm.bf16i4bf16_rowwise
     """

-    def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
+    def __init__(self, config: Int4WeightFakeQuantizeConfig):
         super().__init__()
         self.config = config
         torch._C._log_api_usage_once(
-            "torchao.quantization.qat.Int4WeightPreshuffledFakeQuantizer"
+            "torchao.quantization.qat.Int4WeightFakeQuantizer"
         )

     def forward(self, w: torch.Tensor) -> torch.Tensor:
```
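Because `FakeQuantizerBase.from_config` dispatches on the config type, callers generally never name the quantizer class directly. A usage sketch under the same assumptions as the earlier examples (the weight shape is arbitrary, chosen so `group_size=128` divides the inner dimension; not verified end-to-end on every backend):

```python
import torch

from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizerBase

config = Int4WeightFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.bfloat16,
)

# Dispatches on the config type and returns the renamed Int4WeightFakeQuantizer.
fake_quantizer = FakeQuantizerBase.from_config(config)

w = torch.randn(256, 512, dtype=torch.bfloat16)
w_fq = fake_quantizer(w)  # fake-quantized weight, same shape as w
```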
