Skip to content

Commit 270a859

Browse files
committed
Fix Int4WeightOnlyEmbeddingQATQuantizer.convert path
**Summary:** Fixes the issue where `Int4WeightOnlyEmbeddingQATQuantizer`'s convert path assigned the scales and zero points to the wrong attributes ("scales" and "zeros" instead of "scale" and "zero_point"), and also ensures the precisions are correctly set. **Test Plan:** python test/quantization/test_qat.py -k test_qat_4w_embedding
1 parent 5a78b70 commit 270a859

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

test/quantization/test_qat.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -788,17 +788,43 @@ def test_composable_qat_quantizer(self):
788788
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
789789
)
790790
def test_qat_4w_embedding(self):
791+
from torchao._executorch_ops import (
792+
_quantized_decomposed_quantize_per_channel_group_wrapper,
793+
)
791794
from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
792795

796+
group_size = 256
793797
model = M2()
794798
x = model.example_inputs()
795799
model(*x)
796-
quantizer = Int4WeightOnlyEmbeddingQATQuantizer()
800+
quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size)
797801
prepared = quantizer.prepare(model)
802+
prepared_embedding_weight = copy.deepcopy(prepared.embedding.weight)
798803
prepared(*x)
799804
converted = quantizer.convert(model)
800805
converted(*x)
801806

807+
# Assert the scales, zero points, and weights are correct after convert
808+
qmin, qmax = -8, 7
809+
(s, zp) = get_group_qparams_symmetric(
810+
prepared_embedding_weight,
811+
4,
812+
group_size,
813+
)
814+
zp = zp.to(torch.int32)
815+
q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
816+
prepared_embedding_weight,
817+
s,
818+
zp,
819+
qmin,
820+
qmax,
821+
torch.int8,
822+
group_size,
823+
)
824+
torch.testing.assert_close(converted.embedding.weight, q_weight)
825+
torch.testing.assert_close(converted.embedding.scale, s)
826+
torch.testing.assert_close(converted.embedding.zero_point, zp)
827+
802828
def test_fake_quantize_config_granularity(self):
803829
"""
804830
Test initialization and property setting of `FakeQuantizeConfig`'s granularity.

torchao/quantization/qat/embedding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ def _convert_helper(self, module: torch.nn.Module):
245245
group_size,
246246
)
247247
quantized_embedding.weight = q_weight
248-
quantized_embedding.scales = s
249-
quantized_embedding.zeros = zp
248+
quantized_embedding.scale = s.to(scale_precision)
249+
quantized_embedding.zero_point = zp.to(zero_point_precision)
250250
else:
251251
self._convert_helper(child)
252252

0 commit comments

Comments
 (0)