From ca6b24650e397041a8b9db5fb4dad2aac8784fa9 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Fri, 14 Feb 2025 11:51:55 -0800
Subject: [PATCH 1/2] Switch to new ao quant api for 8da4w

---
 examples/models/llama/source_transformation/quantize.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 8923ab1fdec..c1b39fced70 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -119,11 +119,12 @@ def quantize(  # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
 
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
-        ).quantize(model)
+        from torchao.quantization import (
+            quantize_,
+            int8_dynamic_activation_int4_weight,
+        )
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         if verbose:
             print("quantized model:", model)

From 285d20bfeb1db17427e83ef453d221a1de85b44d Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Tue, 25 Feb 2025 12:09:51 -0800
Subject: [PATCH 2/2] Fix quantize embedding error

---
 examples/models/llama/source_transformation/quantize.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index c1b39fced70..e6d228d5da9 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -120,10 +120,8 @@ def quantize(  # noqa C901
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
 
-        from torchao.quantization import (
-            quantize_,
-            int8_dynamic_activation_int4_weight,
-        )
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
         quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         if verbose:
@@ -664,7 +662,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
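
Note on the API migration (not part of the patches above): a minimal, self-contained sketch of how the new torchao flow behaves, assuming torchao >= 0.4 (where quantize_ and int8_dynamic_activation_int4_weight are exported from torchao.quantization) and PyTorch >= 2.1 (where load_state_dict accepts assign=True). The toy model and group size below are placeholders, not taken from ExecuTorch.

import torch
import torch.nn as nn
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

# Toy stand-in for the Llama model; any nn.Module containing nn.Linear
# layers whose in_features is divisible by group_size works.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))

# Unlike the old Int8DynActInt4WeightQuantizer(...).quantize(model), which
# returned the quantized module, quantize_ mutates `model` in place,
# swapping each eligible nn.Linear weight for an int4 grouped-quantized
# tensor with dynamic int8 activation quantization (8da4w).
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

x = torch.randn(1, 256)
print(model(x).shape)  # torch.Size([1, 64])

The assign=True added in the second patch matters because the quantized state dict holds tensors whose dtypes and types no longer match the module's original embedding parameters; with assign=True, load_state_dict replaces the module's parameters with the state-dict tensors outright instead of copying values into the existing ones, which is what otherwise raised the embedding error the commit subject refers to.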