From cdd658b6c7420355e2bdd0ac20f80bca333cf1fe Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Wed, 18 Sep 2024 12:37:23 +0800
Subject: [PATCH 1/4] refine autoawq exporting code

---
 .../export/export_to_autogptq/export.py    |  2 +-
 auto_round/export/export_to_awq/export.py  | 83 +++++++++++++------
 2 files changed, 57 insertions(+), 28 deletions(-)

diff --git a/auto_round/export/export_to_autogptq/export.py b/auto_round/export/export_to_autogptq/export.py
index f051304a..7f8a6800 100644
--- a/auto_round/export/export_to_autogptq/export.py
+++ b/auto_round/export/export_to_autogptq/export.py
@@ -37,7 +37,7 @@
 import torch
 
 from auto_round.utils import check_to_quantized, get_block_names, \
-    get_module, logger, get_layer_names_in_block, set_module
+    get_module, logger, set_module
 import copy
 import json
 import os
diff --git a/auto_round/export/export_to_awq/export.py b/auto_round/export/export_to_awq/export.py
index 00de4a5c..83f6824e 100644
--- a/auto_round/export/export_to_awq/export.py
+++ b/auto_round/export/export_to_awq/export.py
@@ -35,10 +35,46 @@
 import torch
 import torch.nn as nn
 from auto_round.export.register import register_format
-from auto_round.utils import convert_dtype_torch2str_hf, logger
+from auto_round.utils import convert_dtype_torch2str_hf, logger, get_module, set_module
 import copy
 import json
 from typing import Dict, List, Optional, Union
+from .utils import WQLinear_GEMM, clear_memory, get_self_modules
+from concurrent.futures import ThreadPoolExecutor
+import threadpoolctl as tctl
+from tqdm import tqdm
+
+
+def pack_layer(name, model, layer_config, backend, pbar):
+    with tctl.threadpool_limits(limits=1):
+        pbar.set_description(f"packing {name}")
+        if name == "lm_head":  ## does not support lm-head
+            pbar.update(1)
+            return
+        config = layer_config[name]
+        if config["bits"] > 8:
+            pbar.update(1)
+            return
+        scale, zp = config["scale"], config["zp"]
+        scale = scale.t().contiguous()
+        zp = zp.t().contiguous()
+        config["zp"] = config["zp"].to(torch.float32)
+        bits = config["bits"]
+        group_size = config["group_size"]
+        linear_layer = get_module(model, name)
+        q_linear = WQLinear_GEMM.from_linear(
+            linear=linear_layer,
+            w_bit=bits,
+            group_size=group_size,
+            init_only=False,
+            scales=scale,
+            zeros=zp,
+        )
+        linear_layer.cpu()
+        q_linear.to("cpu")
+        set_module(model, name, q_linear)
+        clear_memory()
+        pbar.update(1)
 
 
 @register_format("auto_awq")
@@ -67,36 +103,30 @@ def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs):
     else:
         compressed_model = copy.deepcopy(model.to("cpu"))
 
-    from .utils import WQLinear_GEMM, clear_memory, get_self_modules
+    names = list(layer_config.keys())
 
-    q_linear_module = WQLinear_GEMM
     self_modules = get_self_modules(compressed_model)
+    layers = []
     for i in range(len(self_modules)):
         module = self_modules[i]
         named_linears = get_named_linears(module)
         for name, linear_layer in named_linears.items():
             key = get_module_name(compressed_model, linear_layer)
-            logger.info(f"packing {name}")
+            layers.append(key)
             config = layer_config[key]
             if config["bits"] > 8:
                 modules_to_not_convert.append(name)
-                continue
-            config["zp"] = config["zp"].to(torch.float32)
-            scale, zp = config["scale"], config["zp"]
-            scale = scale.t().contiguous()
-            zp = zp.t().contiguous()
-            q_linear = q_linear_module.from_linear(
-                linear=linear_layer,
-                w_bit=bits,
-                group_size=group_size,
-                init_only=False,
-                scales=scale,
-                zeros=zp,
-            )
-            linear_layer.cpu()
-            q_linear.to(next(module.parameters()).device)
-            set_op_by_name(module, name, q_linear)
-            clear_memory()
+
+    backend = None
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        with tqdm(total=len(names), leave=True) as pbar:
+            def wrapper(name):
+                pack_layer(name, model, layer_config, backend, pbar)
+
+            for _ in executor.map(wrapper, names):
+                pass
+    if output_dir is None:
+        return model
 
     quant_config = {}
     quant_config["quant_method"] = "awq"
@@ -123,11 +153,11 @@ def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs):
 
 
 def save_quantized(
-    model,
-    save_dir,
-    quant_config,
-    safetensors=True,
-    shard_size="5GB",
+        model,
+        save_dir,
+        quant_config,
+        safetensors=True,
+        shard_size="5GB",
 ):
     save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir
 
@@ -220,4 +250,3 @@ def get_module_name(model, module_to_find):
         if module is module_to_find:
             return name
     return None
-

From 0a8049f0699889e0998e7f6d12b2ae4203eae07d Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Wed, 25 Sep 2024 09:29:38 +0800
Subject: [PATCH 2/4] added VLM support

---
 README.md               | 10 +++++-----
 auto_round/autoround.py |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 8173e6eb..4523dbd9 100644
--- a/README.md
+++ b/README.md
@@ -26,11 +26,11 @@ more accuracy data and recipes across various models.
 ## What's New
 
-
+* [2024/09] AutoRound format supports several LVM models; check out the examples for [Qwen2-VL](./examples/multimodal-modeling/Qwen-VL), [Phi-3-vision](./examples/multimodal-modeling/Phi-3-vision), and [Llava](./examples/multimodal-modeling/Llava)
 * [2024/08] AutoRound format supports Intel Gaudi2 devices. For an example, please refer to [Intel/Qwen2-7B-int4-inc](https://huggingface.co/Intel/Qwen2-7B-int4-inc).
-* [2024/08] AutoRound includes several experimental features, e.g., activation quantization, mx_fp data type, and fast
-  tuning of norm/bias parameters.
+* [2024/08] AutoRound introduces several experimental features, including fast tuning of norm/bias parameters (for 2-bit
+  and W4A4), activation quantization, and the mx_fp data type.
 * [2024/07] Important change: the default value of nsamples has been changed from 512 to 128 to reduce the memory
   usages, which may cause a slight accuracy drop in some scenarios
@@ -173,7 +173,7 @@ We provide two recipes for best accuracy and fast running speed with low memory.
 
 #### Formats
 
-**AutoRound format**:This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision inference. [2,4]
+**AutoRound Format**: This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision inference. [2,4]
 bits are supported. It resolves the asymmetric quantization kernel issues found in the AutoGPTQ format and supports
 both LM-head quantization and mixed precision. However, it has not yet gained widespread community adoption. For CUDA
 support, you will need to
@@ -186,7 +186,7 @@ asymmetric kernel has issues** that can cause considerable accuracy drops, parti
 models. Additionally, symmetric quantization tends to perform poorly at 2-bit precision.
 
-**AutoAWQ format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted
+**AutoAWQ Format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted
 within the community, only 4-bits quantization is supported. Asymmetric quantization typically improves accuracy but
 may reduce inference speed. It features specialized layer fusion tailored for Llama models.
 
diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index d87e46b1..449b9ec0 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -1176,7 +1176,7 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
                 "the AutoRound format (2 bits) to enhance performance."
             )
         if "awq" in format and not self.bits == 4:
-            raise ValueError("The AWQ format only supports W4 asym quantization ")
+            raise ValueError("The AWQ format only supports W4 quantization")
 
         serialization_keys = [
             "bits",

From 530b7227dae25cc304eb321fc2e49601a7d86432 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Wed, 25 Sep 2024 09:51:49 +0800
Subject: [PATCH 3/4] add integration

---
 README.md | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/README.md b/README.md
index 4523dbd9..5f043d2f 100644
--- a/README.md
+++ b/README.md
@@ -308,6 +308,19 @@ release most of the models ourselves.
 | bigscience/bloom-3b | [accuracy](./docs/bloom-3B-acc.md), [recipe](./examples/language-modeling/scripts/bloom-3b.sh), [example](./examples/language-modeling/) |
 | EleutherAI/gpt-j-6b | [accuracy](./docs/gpt-j-6B-acc.md), [recipe](./examples/language-modeling/scripts/gpt-j-6b.sh), [example](./examples/language-modeling/) |
 
+
+## Integration
+AutoRound has been integrated into multiple repositories.
+
+[ModelCloud/GPTQModel](https://github.com/ModelCloud/GPTQModel)
+
+[Intel Neural Compressor](https://github.com/intel/neural-compressor)
+
+[pytorch/ao](https://github.com/pytorch/ao)
+
+
+
 ## Reference
 
 If you find AutoRound useful for your research, please cite our paper:

From f6ce98e94454a280580cfa2cb0d4286c34d63998 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Wed, 25 Sep 2024 09:59:31 +0800
Subject: [PATCH 4/4] alphabet

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 5f043d2f..6f0c102d 100644
--- a/README.md
+++ b/README.md
@@ -312,10 +312,10 @@ release most of the models ourselves.
 ## Integration
 AutoRound has been integrated into multiple repositories.
 
-[ModelCloud/GPTQModel](https://github.com/ModelCloud/GPTQModel)
-
 [Intel Neural Compressor](https://github.com/intel/neural-compressor)
 
+[ModelCloud/GPTQModel](https://github.com/ModelCloud/GPTQModel)
+
 [pytorch/ao](https://github.com/pytorch/ao)
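
Below is a minimal, self-contained sketch of the packing pattern adopted in PATCH 1/4: per-layer work is fanned out over a small `ThreadPoolExecutor`, BLAS threads are capped with `threadpoolctl` so the workers do not oversubscribe the CPU, and a shared `tqdm` bar reports progress. The `pack_one` worker and the toy `layers` dict are hypothetical stand-ins used for illustration only; they are not part of the auto_round API.

```python
# Minimal sketch (not auto_round code) of the thread-pool packing pattern from PATCH 1/4.
# `pack_one` and `layers` are illustrative placeholders for the real WQLinear_GEMM packing.
from concurrent.futures import ThreadPoolExecutor

import threadpoolctl as tctl
from tqdm import tqdm


def pack_one(name, layers, pbar):
    """Pack a single layer while keeping BLAS libraries single-threaded."""
    with tctl.threadpool_limits(limits=1):  # avoid oversubscription inside each worker
        pbar.set_description(f"packing {name}")
        layers[name] = f"packed({layers[name]})"  # placeholder for the quantized-linear swap
        pbar.update(1)


def pack_all(layers, max_workers=2):
    names = list(layers.keys())
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        with tqdm(total=len(names), leave=True) as pbar:
            # Consume the map iterator so any worker exception propagates here.
            for _ in executor.map(lambda name: pack_one(name, layers, pbar), names):
                pass
    return layers


if __name__ == "__main__":
    print(pack_all({"model.layers.0.mlp.down_proj": "fp16 weight",
                    "model.layers.1.mlp.down_proj": "fp16 weight"}))
```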