diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index fcd4969bbf..4e267b7124 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -788,17 +788,43 @@ def test_composable_qat_quantizer(self): not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower" ) def test_qat_4w_embedding(self): + from torchao._executorch_ops import ( + _quantized_decomposed_quantize_per_channel_group_wrapper, + ) from torchao.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer + group_size = 256 model = M2() x = model.example_inputs() model(*x) - quantizer = Int4WeightOnlyEmbeddingQATQuantizer() + quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size) prepared = quantizer.prepare(model) + prepared_embedding_weight = copy.deepcopy(prepared.embedding.weight) prepared(*x) converted = quantizer.convert(model) converted(*x) + # Assert the scales, zero points, and weights are correct after convert + qmin, qmax = -8, 7 + (s, zp) = get_group_qparams_symmetric( + prepared_embedding_weight, + 4, + group_size, + ) + zp = zp.to(torch.int32) + q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper( + prepared_embedding_weight, + s, + zp, + qmin, + qmax, + torch.int8, + group_size, + ) + torch.testing.assert_close(converted.embedding.weight, q_weight) + torch.testing.assert_close(converted.embedding.scale, s) + torch.testing.assert_close(converted.embedding.zero_point, zp) + def test_fake_quantize_config_granularity(self): """ Test initialization and property setting of `FakeQuantizeConfig`'s granularity. diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py index cc63c5181d..42e9b08eed 100644 --- a/torchao/quantization/qat/embedding.py +++ b/torchao/quantization/qat/embedding.py @@ -245,8 +245,8 @@ def _convert_helper(self, module: torch.nn.Module): group_size, ) quantized_embedding.weight = q_weight - quantized_embedding.scales = s - quantized_embedding.zeros = zp + quantized_embedding.scale = s.to(scale_precision) + quantized_embedding.zero_point = zp.to(zero_point_precision) else: self._convert_helper(child)