Commit 58c3064

Rename Int4WeightPreshuffledFakeQuantizeConfig (#3005)
**Summary:** This config actually works for both preshuffled and plain int4 QAT, so we remove "Preshuffled" from the name.

BC-breaking notes:

```
Int4WeightPreshuffledFakeQuantizeConfig -> Int4WeightFakeQuantizeConfig
Int4WeightPreshuffledFakeQuantizer -> Int4WeightFakeQuantizer
```

**Test Plan:**

```
python test/quantization/test_qat.py
```
1 parent 9a770a5 commit 58c3064

3 files changed: 14 additions, 18 deletions
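Since the rename ships as BC-breaking (per the note above, with no mention of a deprecation alias), downstream code needs a one-line update. A minimal migration sketch, using the import paths that appear in the test diff below:

```python
# Old names (removed by this commit):
#   from torchao.quantization.qat.fake_quantize_config import (
#       Int4WeightPreshuffledFakeQuantizeConfig,
#   )
#   from torchao.quantization.qat.fake_quantizer import (
#       Int4WeightPreshuffledFakeQuantizer,
#   )

# New names:
from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    Int4WeightFakeQuantizer,
)
```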

test/quantization/test_qat.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -50,7 +50,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
-    Int4WeightPreshuffledFakeQuantizeConfig,
+    Int4WeightFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -2049,7 +2049,7 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, e4m3_dtype)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
+        self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
         self.assertEqual(weight_config.activation_dtype, e4m3_dtype)
 
@@ -2072,7 +2072,7 @@ def test_infer_int4_weight_only_config(self):
         base_config = Int4WeightOnlyConfig(version=2)
         (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         self.assertIsNone(act_config)
-        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
+        self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
         self.assertEqual(weight_config.activation_dtype, torch.bfloat16)
 
@@ -2166,7 +2166,7 @@ def test_fbgemm_fp8_int4_preshuffled_primitives(self):
         """
         Compare numerics between:
         (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
-        (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        (2) Our reference QAT version in `Int4WeightFakeQuantizer`
         """
         from fbgemm_gpu.experimental.gen_ai.quantize import (
             int4_row_quantize,
@@ -2248,7 +2248,7 @@ def test_fbgemm_int4_weight_only_primitives(self):
         """
         Compare numerics between:
         (1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp
-        (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        (2) Our reference QAT version in `Int4WeightFakeQuantizer`
         """
         from fbgemm_gpu.experimental.gen_ai.quantize import (
             int4_row_quantize_zp,
```
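The assertions above encode the intended mapping from base configs to QAT fake-quantize configs. A condensed sketch of the weight-only case; note that `_infer_fake_quantize_configs` is a private helper and the `Int4WeightOnlyConfig` import path is an assumption here:

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig  # assumed import path
from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
    _infer_fake_quantize_configs,  # private helper, subject to change
)

# int4 weight-only (version=2) infers: no activation fake quantization,
# int4 weight fake quantization with bf16 activations and group size 128
act_config, weight_config = _infer_fake_quantize_configs(
    Int4WeightOnlyConfig(version=2)
)
assert act_config is None
assert isinstance(weight_config, Int4WeightFakeQuantizeConfig)
assert weight_config.group_size == 128
assert weight_config.activation_dtype == torch.bfloat16
```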

torchao/quantization/qat/fake_quantize_config.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -78,9 +78,8 @@ def __post_init__(self):
             )
 
 
-# TODO: rename this config, it actually works for both plain and preshuffled
 @dataclass
-class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
+class Int4WeightFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
     Config for int4 weight fake quantization that targets the numerics in the following preshuffled kernel:
         torch.ops.fbgemm.f8i4bf16_shuffled
@@ -395,7 +394,7 @@ def _infer_fake_quantize_configs(
             raise ValueError(
                 f"Packing format must be one of {supported_packing_formats}"
             )
-        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+        weight_config = Int4WeightFakeQuantizeConfig(
            group_size=128,
            activation_dtype=torch.bfloat16,
        )
@@ -438,7 +437,7 @@ def _infer_fake_quantize_configs(
             dtype=e4m3_dtype,
             granularity=PerRow(),
         )
-        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+        weight_config = Int4WeightFakeQuantizeConfig(
             group_size=128,
             activation_dtype=e4m3_dtype,
         )
```
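The renamed config keeps its fields unchanged; a minimal construction sketch for the two inferred variants visible in this diff (`e4m3_dtype` in the surrounding code corresponds to `torch.float8_e4m3fn`, per the quantizer's forward below):

```python
import torch
from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
)

# Plain int4 path: bf16 activations (inferred from Int4WeightOnlyConfig)
bf16_weight_config = Int4WeightFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.bfloat16,
)

# fp8-int4 path: e4m3 float8 activations
fp8_weight_config = Int4WeightFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.float8_e4m3fn,
)
```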

torchao/quantization/qat/fake_quantizer.py

Lines changed: 6 additions & 9 deletions
```diff
@@ -35,7 +35,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
-    Int4WeightPreshuffledFakeQuantizeConfig,
+    Int4WeightFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .utils import (
@@ -68,8 +68,8 @@ def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":
 
         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
-        elif isinstance(config, Int4WeightPreshuffledFakeQuantizeConfig):
-            return Int4WeightPreshuffledFakeQuantizer(config)
+        elif isinstance(config, Int4WeightFakeQuantizeConfig):
+            return Int4WeightFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
         elif isinstance(config, NVFP4FakeQuantizeConfig):
@@ -103,8 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq
 
 
-# TODO: rename this, it also works for plain Int4Tensor
-class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
+class Int4WeightFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying int4 fake quantization to a weight tensor,
     targeting the following FBGEMM kernels:
@@ -113,12 +112,10 @@ class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
         torch.ops.fbgemm.bf16i4bf16_rowwise
     """
 
-    def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
+    def __init__(self, config: Int4WeightFakeQuantizeConfig):
         super().__init__()
         self.config = config
-        torch._C._log_api_usage_once(
-            "torchao.quantization.qat.Int4WeightPreshuffledFakeQuantizer"
-        )
+        torch._C._log_api_usage_once("torchao.quantization.qat.Int4WeightFakeQuantizer")
 
     def forward(self, w: torch.Tensor) -> torch.Tensor:
         if self.config.activation_dtype == torch.float8_e4m3fn:
```
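End to end, the renamed quantizer is reached via the `FakeQuantizerBase.from_config` dispatch shown above. A usage sketch; the tensor shape is illustrative, and whether the reference forward runs on CPU or needs the FBGEMM CUDA kernels is not established by this diff:

```python
import torch
from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
)
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizerBase,
    Int4WeightFakeQuantizer,
)

config = Int4WeightFakeQuantizeConfig(
    group_size=128,
    activation_dtype=torch.bfloat16,
)

# from_config dispatches on the config type (see the diff above)
fake_quantizer = FakeQuantizerBase.from_config(config)
assert isinstance(fake_quantizer, Int4WeightFakeQuantizer)

# Fake quantization of a weight tensor preserves its shape
w = torch.randn(256, 512, dtype=torch.bfloat16)  # illustrative shape
fake_w = fake_quantizer(w)
assert fake_w.shape == w.shape
```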
