Commit bcd64bf
Fix Int4WeightOnlyEmbeddingQATQuantizer.convert path
**Summary:** Fixes the issue where `Int4WeightOnlyEmbeddingQATQuantizer`'s convert path assigned the scales and zero points to the wrong attributes (`scales` and `zeros` instead of `scale` and `zero_point`), and also ensures the precisions of these tensors are set correctly. **Test Plan:** python test/quantization/test_qat.py -k test_qat_4w_embedding
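For context, anything consuming the converted embedding reads its qparams from `scale` and `zero_point`, so values assigned to `scales`/`zeros` would simply be ignored. A minimal sketch of group-wise dequantization, assuming int4 values stored in int8 with one scale/zero point per group (illustrative names and shapes, not the torchao implementation):

```python
import torch

def dequantize_per_channel_group(q_weight, scale, zero_point, group_size):
    # q_weight: (num_embeddings, embedding_dim), int8 holding 4-bit values in [-8, 7]
    # scale / zero_point: (num_embeddings, embedding_dim // group_size)
    num_rows, num_cols = q_weight.shape
    q = q_weight.to(torch.float32).reshape(num_rows, num_cols // group_size, group_size)
    s = scale.to(torch.float32).unsqueeze(-1)        # broadcast one scale per group
    zp = zero_point.to(torch.float32).unsqueeze(-1)  # broadcast one zero point per group
    return ((q - zp) * s).reshape(num_rows, num_cols)
```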
1 parent: 5a78b70

File tree: 2 files changed, +27 -3 lines


test/quantization/test_qat.py

+25 -1

```diff
@@ -789,16 +789,40 @@ def test_composable_qat_quantizer(self):
         )
     def test_qat_4w_embedding(self):
         from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )
 
+        group_size = 256
         model = M2()
         x = model.example_inputs()
         model(*x)
-        quantizer = Int4WeightOnlyEmbeddingQATQuantizer()
+        quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size)
         prepared = quantizer.prepare(model)
+        prepared_embedding_weight = copy.deepcopy(prepared.embedding.weight)
         prepared(*x)
         converted = quantizer.convert(model)
         converted(*x)
 
+        # Assert the scales, zero points, and weights are correct after convert
+        qmin, qmax = -8, 7
+        (s, zp) = get_group_qparams_symmetric(
+            prepared_embedding_weight, 4, group_size,
+        )
+        zp = zp.to(torch.int32)
+        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+            prepared_embedding_weight,
+            s,
+            zp,
+            qmin,
+            qmax,
+            torch.int8,
+            group_size,
+        )
+        torch.testing.assert_close(converted.embedding.weight, q_weight)
+        torch.testing.assert_close(converted.embedding.scale, s)
+        torch.testing.assert_close(converted.embedding.zero_point, zp)
+
     def test_fake_quantize_config_granularity(self):
         """
         Test initialization and property setting of `FakeQuantizeConfig`'s granularity.
```
torchao/quantization/qat/embedding.py

+2 -2

```diff
@@ -245,8 +245,8 @@ def _convert_helper(self, module: torch.nn.Module):
                     group_size,
                 )
                 quantized_embedding.weight = q_weight
-                quantized_embedding.scales = s
-                quantized_embedding.zeros = zp
+                quantized_embedding.scale = s.to(scale_precision)
+                quantized_embedding.zero_point = zp.to(zero_point_precision)
             else:
                 self._convert_helper(child)
 
```
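End to end, the flow this fix affects looks roughly like the following (a sketch only; `model` stands in for a module with an `embedding` child, like `M2` in the test):

```python
from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer

quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size=256)
prepared = quantizer.prepare(model)   # swaps in fake-quantized embeddings for QAT
# ... fine-tune `prepared` here ...
converted = quantizer.convert(model)  # swaps in int4 weights; with this fix the
                                      # qparams land on `scale` / `zero_point`
```

After convert, `converted.embedding.scale` and `converted.embedding.zero_point` carry the group qparams in the configured `scale_precision` and `zero_point_precision`.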