diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 8923ab1fdec..e6d228d5da9 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -119,11 +119,10 @@ def quantize( # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
 
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
-        ).quantize(model)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         if verbose:
             print("quantized model:", model)
@@ -663,7 +662,7 @@ def convert_for_runtime(self) -> nn.Module:
 
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
 