From 6b0ca2d6cbb338c2e72cec16fc12b63d38d96ae5 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 18 Jun 2024 12:54:49 -0400 Subject: [PATCH] fixing scripts (#395) Summary: a few bugfixes for scripts 1) convert_hf_checkpoint.py had a gpt-fast dependency that wasn't caught due to it being in the path 2) eval.py had a bug due to the switch to aqt apis 3) generate.py had a bug due to the deprecation of the old quant apis Test Plan: python eval.py python eval.py -q int8dq --compile --limit 2 python eval.py -q int8wo --compile --limit 2 python eval.py -q int4wo-64 --compile --limit 2 python eval.py -q int4wo-64-gptq --compile sh benchmarks.sh (going to add the output results once they finish Reviewers: Subscribers: Tasks: Tags: --- scripts/convert_hf_checkpoint.py | 7 +------ torchao/_models/llama/eval.py | 8 +++++--- torchao/_models/llama/generate.py | 17 +++++++++++------ 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 6a7da922a..7b0f76903 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -8,17 +8,12 @@ import json import re import shutil -import sys from pathlib import Path from typing import Optional import torch -# support running without installing as a package -wd = Path(__file__).parent.parent.resolve() -sys.path.append(str(wd)) - -from model import ModelArgs +from torchao._models.llama.model import ModelArgs @torch.inference_mode() diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py index cccc13f09..7842c3e66 100644 --- a/torchao/_models/llama/eval.py +++ b/torchao/_models/llama/eval.py @@ -13,7 +13,7 @@ ) from torchao.quantization.quant_api import ( - quantize, int4wo, int8wo, int8da_int8w + quantize, int4wo, int8wo, int8da_int8w, unwrap_tensor_subclass ) from torchao._models._eval import TransformerEvalWrapper, InputRecorder @@ -70,7 +70,7 @@ def run_evaluation( if "int4wo" in quantization and "gptq" in quantization: groupsize=int(quantization.split("-")[-2]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - + assert precision==torch.bfloat16, f"{quantization} requires precision or bfloat16 but got {precision}" inputs = InputRecorder( tokenizer, calibration_seq_length, @@ -83,9 +83,11 @@ def run_evaluation( calibration_limit, ).get_inputs() - quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, precision=precision) + quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize) model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length) model = quantizer.quantize(model, inputs).to(device) + else: + unwrap_tensor_subclass(model) if compile: model = torch.compile(model, mode="max-autotune", fullgraph=True) diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index ea7200ea6..a230ea4a3 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -189,20 +189,22 @@ def main( if quantization: from torchao.quantization.quant_api import ( - change_linear_weights_to_int4_woqtensors, - change_linear_weights_to_int8_woqtensors, - change_linear_weights_to_int8_dqtensors, + quantize, + int8wo, + int8da_int8w, + int4wo, autoquant, + unwrap_tensor_subclass ) if "int8wo" in quantization: - change_linear_weights_to_int8_woqtensors(model) + quantize(model, int8wo()) if "int8dq" in quantization: - change_linear_weights_to_int8_dqtensors(model) + quantize(model, int8da_int8w()) if "int4wo" in quantization: groupsize=int(quantization.split("-")[-1]) assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}" - change_linear_weights_to_int4_woqtensors(model, groupsize=groupsize) + quantize(model, int4wo(groupsize=groupsize)) if "autoquant" == quantization: model = autoquant(model) generate( @@ -211,6 +213,9 @@ def main( 2, interactive=False ) + else: + unwrap_tensor_subclass(model) + model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9