diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 3f3a1a14e..6b90b38a9 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -956,6 +956,20 @@ def test_weight_only_groupwise_quant(self):
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 45.0)
 
+    def test_weight_only_groupwise_embedding_quant(self):
+        group_size = 64
+        m = nn.Embedding(4096, 128)
+        input = torch.randint(0, 4096, (1, 6))
+
+        quantize_(m, int8_weight_only(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding))
+        y_q = m(input)
+        y_ref = m.weight.dequantize()[input]
+
+        sqnr = compute_error(y_ref, y_q)
+
+        self.assertGreater(sqnr, 45.0)
+
+
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index b25d9dcfa..863550481 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -237,6 +237,8 @@ def main(
             quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
+        if "embed-int8wo" in quantization:
+            quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))
         if quantization.startswith("awq"):
             from torchao._models._eval import TransformerEvalWrapper
             from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
@@ -463,7 +465,7 @@ def callback(x):
     parser.add_argument('-q', '--quantization', type=str,
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
-            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant'
+            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
         )
     )
     parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 944106767..ed470e69f 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1840,6 +1840,32 @@ def _(func, types, args, kwargs):
         weight_tensor = weight_tensor.dequantize()
     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
 
+@implements(torch.nn.functional.embedding)
+def _(func, types, args, kwargs):
+    # new_arg1 = args[1].dequantize()
+    # return torch.nn.functional.embedding(args[0], new_arg1, *args[2:], **kwargs)
+    assert isinstance(args[1].tensor_impl, PlainAQTTensorImpl), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}"
+    assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None and not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"] == 2.0
+    idx = args[0]
+    int_data, scale, zero_point = args[1].tensor_impl.get_plain()
+
+    sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx]
+    # block_size expects two dimensions, [1, group_size], but indexing with idx adds
+    # batch (and other) dims to sliced_data, sliced_scale and sliced_zero_point, so
+    # we prepend 1s to block_size to match the new number of dims
+    new_blocks = idx.dim() - 1
+    return dequantize_affine(
+        sliced_data,
+        new_blocks * [1] + list(args[1].block_size),
+        sliced_scale,
+        sliced_zero_point,
+        sliced_data.dtype,
+        args[1].quant_min,
+        args[1].quant_max,
+        args[1].zero_point_domain,
+        output_dtype=sliced_scale.dtype,
+    )
+
 @implements(aten.addmm.default)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor, bias = (
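
Usage reference: a minimal end-to-end sketch of the new embedding path, assuming the torchao APIs exercised in this diff (quantize_, int8_weight_only, and the F.embedding override); the module sizes and group_size below are illustrative, not prescribed by the patch.

# Minimal sketch: int8 weight-only quantization of an nn.Embedding via quantize_,
# which routes lookups through the new torch.nn.functional.embedding handler.
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

emb = nn.Embedding(4096, 128)  # illustrative vocab/embedding sizes

# Quantize only embedding modules; the filter_fn mirrors the one added to generate.py,
# so linear layers in a larger model would be left untouched.
quantize_(
    emb,
    int8_weight_only(group_size=64),
    filter_fn=lambda m, *args: isinstance(m, nn.Embedding),
)

# The lookup dispatches to the @implements(torch.nn.functional.embedding) handler,
# which slices the int data / scales / zero points by the indices and dequantizes
# only the gathered rows.
idx = torch.randint(0, 4096, (2, 8))
out = emb(idx)
print(out.shape)  # torch.Size([2, 8, 128])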