support autoawq format #115

Merged
merged 48 commits into main from lyt/autoawq on Aug 5, 2024
48 commits (the file changes below reflect 17 of them)
c7770af
add file
yintong-lu May 21, 2024
008899d
add file
yintong-lu May 21, 2024
49deb12
update
yintong-lu May 21, 2024
b1b729d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
9e6cf36
update
yintong-lu May 23, 2024
86b9b96
fix import error
yintong-lu May 23, 2024
10a91e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
4807994
minor change
yintong-lu May 23, 2024
5822543
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
991b5e9
update
yintong-lu May 23, 2024
4051f55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
f966812
fix pylint error
yintong-lu May 23, 2024
5a1188c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
2942bf9
Merge branch 'main' into lyt/autoawq
yintong-lu May 24, 2024
9c4f73b
minor fix
yintong-lu May 24, 2024
05d6e1d
fix pylint issue
yintong-lu May 24, 2024
72b48b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
40e8ce7
update import of awq
yintong-lu May 24, 2024
55b05e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
07c0f18
fix import error
yintong-lu May 24, 2024
8184d84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
02687ac
modify code
yintong-lu May 27, 2024
b4b6fa3
Merge branch 'main' into lyt/autoawq
yintong-lu May 27, 2024
6d6b8b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
e91fb20
fix doc typo
yintong-lu May 27, 2024
a775b21
fix mixtral issue
yintong-lu May 27, 2024
6e76f05
Merge branch 'main' into lyt/autoawq
yintong-lu May 27, 2024
fc39b70
mv awq to autoround format, fix mixtral issue and follow autoround fo…
yintong-lu May 28, 2024
3480b8d
minor fix
yintong-lu May 28, 2024
406410c
fix conflicts
yintong-lu May 29, 2024
949a262
fix conflicts
yintong-lu Jun 11, 2024
81576c5
mv to autoround format
yintong-lu Jun 14, 2024
1267b5f
Merge branch 'main' into lyt/autoawq
yintong-lu Jun 25, 2024
a5ceb0b
minor fix
yintong-lu Jun 25, 2024
7d505a0
typo
yintong-lu Jun 25, 2024
56e60d7
remove comments
yintong-lu Jun 25, 2024
3385eb9
update comments
yintong-lu Jun 28, 2024
5ec30d4
move awq to autoround format evaluation
yintong-lu Jul 3, 2024
d5537f4
pylint error fixing
yintong-lu Jul 3, 2024
821a1b9
fix conflicts, and add awq format
yintong-lu Jul 24, 2024
7980a58
fix conflicts
yintong-lu Jul 24, 2024
5b271ef
Merge branch 'main' into lyt/autoawq
yintong-lu Jul 24, 2024
1f87cad
add ut
yintong-lu Jul 24, 2024
c888e12
add requirement
yintong-lu Jul 24, 2024
1bc88db
add coverage test waiver
yintong-lu Jul 24, 2024
bee724a
minor change
yintong-lu Jul 24, 2024
949c48a
refine code to decrease number of branches
yintong-lu Jul 25, 2024
7794386
add ut, fix minor issues
yintong-lu Jul 30, 2024
1 change: 1 addition & 0 deletions auto_round/export/__init__.py
@@ -15,5 +15,6 @@
from .register import EXPORT_FORMAT
from .export_to_autogptq import save_quantized_as_autogptq
from .export_to_itrex import save_quantized_as_itrex, QuantConfig
from .export_to_awq import save_quantized_as_autoawq
from .export_to_autoround.export_to_autoround import save_quantized_as_autoround
from .export_to_autoround import AutoHfQuantizer
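The new import above loads export_to_awq at package-import time, which lets its @register_format("auto_awq") decorator run and add save_quantized_as_autoawq to the EXPORT_FORMAT registry. A minimal sketch of that registry pattern — illustrative only; the project's actual register.py may differ:

# illustrative sketch of a decorator-based format registry (not the exact auto_round code)
EXPORT_FORMAT = {}

def register_format(name):
    """Register an export function under a format name such as "auto_awq"."""

    def decorator(func):
        EXPORT_FORMAT[name] = func
        return func

    return decorator

# a caller can then dispatch on the requested format string, e.g.:
#   export_fn = EXPORT_FORMAT["auto_awq"]
#   export_fn(output_dir, model_path, model=model, ...)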
4 changes: 2 additions & 2 deletions auto_round/export/export_to_autoround/autoround_quantizer.py
@@ -361,13 +361,13 @@ def _replace_by_quant_layers(self, module: nn.Module, layer_configs, backend):
in_features = layer.weight.shape[0]
out_features = layer.weight.shape[1]
bias = layer.bias is not None
-new_layer = QuantLinear(
+new_layer = QuantLinear(  # pylint: disable=E1123
bits,
group_size,
in_features,
out_features,
bias,
-weight_dtype=layer.weight.dtype,  # pylint: disable=E1123
+weight_dtype=layer.weight.dtype,
)

new_layer.device = device
4 changes: 2 additions & 2 deletions auto_round/export/export_to_autoround/export_to_autoround.py
@@ -140,8 +140,8 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="gptq:exllamav
out_features = layer.weight.shape[1]
bias = layer.bias is not None and torch.any(layer.bias)

-new_layer = QuantLinear(
-    bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype  ##pylint: disable=E1123
+new_layer = QuantLinear(  # pylint: disable=E1123
+    bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
)

new_layer.device = device
251 changes: 251 additions & 0 deletions auto_round/export/export_to_awq.py
@@ -0,0 +1,251 @@
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import json
import os
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from auto_round.export.register import register_format
from auto_round.utils import check_to_quantized, convert_dtype_torch2str_hf, get_block_names, get_module, logger

# MIT License

# Copyright (c) 2023 MIT HAN Lab

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


@register_format("auto_awq")
def save_quantized_as_autoawq(output_dir, model_path, inplace=True, **kwargs):
"""Export the model to autogptq format to easily leverage cuda kernel."""
model = kwargs["model"]
weight_config = kwargs["weight_config"]
sym = kwargs["sym"]
bits = kwargs["bits"]
group_size = kwargs["group_size"]
iters = kwargs["iters"]
lr = kwargs["lr"]
minmax_lr = kwargs["minmax_lr"]
enable_minmax_tuning = kwargs["enable_minmax_tuning"]
enable_quanted_input = kwargs["enable_quanted_input"]
scale_dtype = kwargs["scale_dtype"]
tokenizer = kwargs["tokenizer"]
supported_types = kwargs["supported_types"]

logger.info("Saving quantized model to autoawq format")
if tokenizer is not None:
tokenizer.save_pretrained(output_dir)
## Check which supported modules in the first block are quantized; this may be buggy for mixed-precision quantization.
block_name = get_block_names(model)[0]
first_block = get_module(model, block_name)
all_to_quantized = True
modules_in_block_to_quantize = []
for n, m in first_block.named_modules():
is_supported_type = False
for supported_type in supported_types:
if isinstance(m, supported_type):
is_supported_type = True
break
if not is_supported_type:
continue
if not check_to_quantized(m):
all_to_quantized = False
else:
modules_in_block_to_quantize.append(n)
modules_in_block_to_quantize = [modules_in_block_to_quantize]
if all_to_quantized:
modules_in_block_to_quantize = None

if inplace:
compressed_model = model.to("cpu")
else:
compressed_model = copy.deepcopy(model.to("cpu"))

from awq import AutoAWQForCausalLM # pylint: disable=E0401
from awq.modules.linear import WQLinear_GEMM # pylint: disable=E0401
from awq.utils.utils import clear_memory # pylint: disable=E0401

q_linear_module = WQLinear_GEMM
awq_model = AutoAWQForCausalLM.from_pretrained(model_path)
self_modules = awq_model.get_model_layers(compressed_model)
del awq_model # release memory
# replace every quantized nn.Linear in each block with AWQ's packed WQLinear_GEMM
for i in range(len(self_modules)):
module = self_modules[i]
named_linears = get_named_linears(module)
for name, linear_layer in named_linears.items():
key = get_module_name(compressed_model, linear_layer)
info = weight_config[key]
if not check_to_quantized(info):
continue
info["zp"] = info["zp"].to(torch.float32)
scale, zp = info["scale"], info["zp"]
# transpose scales and zero points into the layout expected by WQLinear_GEMM.from_linear
scale = scale.t().contiguous()
zp = zp.t().contiguous()
q_linear = q_linear_module.from_linear(
linear=linear_layer,
w_bit=bits,
group_size=group_size,
init_only=False,
scales=scale,
zeros=zp,
)
linear_layer.cpu()
q_linear.to(next(module.parameters()).device)
set_op_by_name(module, name, q_linear)
clear_memory()

quant_config = {}
quant_config["quant_method"] = "awq"
quant_config["modules_to_not_convert"] = None
quant_config["version"] = "gemm"
quant_config["iters"] = iters
quant_config["lr"] = lr
quant_config["minmax_lr"] = minmax_lr
quant_config["enable_minmax_tuning"] = enable_minmax_tuning
quant_config["enable_quanted_input"] = enable_quanted_input
quant_config["scale_dtype"] = convert_dtype_torch2str_hf(scale_dtype)
quant_config["sym"] = sym
quant_config["bits"] = bits
quant_config["group_size"] = group_size
quant_config["zero_point"] = not sym

save_quantized(compressed_model, save_dir=output_dir, quant_config=quant_config)


from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint


def save_quantized(
    model,
    save_dir,
    quant_config,
    safetensors=True,
    shard_size="10GB",
):
    """Save the AWQ-packed model, its config files, and quantize_config.json to save_dir."""
save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

# Save model
class EmptyModule(nn.Module):
def __init__(self):
super(EmptyModule, self).__init__()

def forward(self, x):
return x

# Save model and config files with empty state dict
awq_quant_config = {
"quant_method": "awq",
"zero_point": quant_config["zero_point"],
"group_size": quant_config["group_size"],
"bits": quant_config["bits"],
"version": "gemm",
"modules_to_not_convert": None,
}

model.config.quantization_config = awq_quant_config
model.generation_config.do_sample = True
model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

# Remove empty state dict
default_paths = [
f"{save_dir}/model.safetensors",
f"{save_dir}/pytorch_model.bin",
]
for path in default_paths:
if os.path.exists(path):
os.remove(path)

# pick the weights file name according to the chosen serialization format
model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

# shard checkpoint into chunks (10GB default)
shards, index = shard_checkpoint(model.state_dict(), max_shard_size=shard_size, weights_name=model_name)

for shard_file, shard in shards.items():
if safetensors:
# safetensors requires each tensor to own contiguous, non-shared memory, so clone before saving
shard = {k: v.clone().contiguous() for k, v in shard.items()}
save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
else:
torch.save(shard, os.path.join(save_dir, shard_file))

# save shard index
if index is not None:
with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
file.write(json.dumps(index, indent=4))

# save quantize_config
with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
json.dump(quant_config, f, indent=2)
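# For reference (illustrative values only; not emitted verbatim by this function): with
# bits=4, group_size=128 and sym=False, the quantization_config that save_pretrained embeds
# in config.json would look roughly like
#
#   "quantization_config": {
#       "quant_method": "awq",
#       "zero_point": true,
#       "group_size": 128,
#       "bits": 4,
#       "version": "gemm",
#       "modules_to_not_convert": null
#   }
#
# which is how downstream loaders typically recognize the checkpoint as AWQ-quantized.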


def get_named_linears(module):
"""Get the name, linear_op pairs of a given module.

Args:
module: A module to be searched.
"""
return {name: m for name, m in module.named_modules() if isinstance(m, torch.nn.Linear)}


def set_op_by_name(layer, name, new_module):
    """Replace the submodule addressed by a dotted name (e.g. "mlp.down_proj") with new_module."""
levels = name.split(".")
if len(levels) > 1:
mod_ = layer
for l_idx in range(len(levels) - 1):
if levels[l_idx].isdigit():
mod_ = mod_[int(levels[l_idx])]
else:
mod_ = getattr(mod_, levels[l_idx])
setattr(mod_, levels[-1], new_module)
else:
setattr(layer, name, new_module)


def get_module_name(model, module_to_find):
"""Get the name of a given module in a model.

Args:
model: The model.
module_to_find: A module to be found.

Returns:
name: The corresponding name of the given module.
"""
for name, module in model.named_modules():
if module is module_to_find:
return name
return None
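End to end, this exporter is reached through AutoRound's save_quantized entry point rather than called directly (the examples/language-modeling change further down shows the exact call). A minimal usage sketch, assuming the usual AutoRound workflow and a 4-bit configuration; the model name and output directory are illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # any causal LM supported by auto-round
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(model, tokenizer, bits=4, group_size=128)
autoround.quantize()

# "auto_awq" dispatches to save_quantized_as_autoawq above; model_path lets the exporter
# rebuild the AutoAWQ wrapper around the original architecture.
autoround.save_quantized("./opt-125m-awq", format="auto_awq", inplace=True, model_path=model_name)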
1 change: 1 addition & 0 deletions auto_round/utils.py
@@ -70,6 +70,7 @@ def __call__(self, *args, **kwargs):

auto_gptq = LazyImport("auto_gptq")
htcore = LazyImport("habana_frameworks.torch.core")
awq = LazyImport("autoawq")


def is_optimum_habana_available():
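The awq handle added above follows the same LazyImport pattern as auto_gptq and htcore: the dependency is only imported when one of its attributes is first accessed, so auto-round does not hard-require autoawq. A rough sketch of what such a helper typically looks like (illustrative; the project's actual LazyImport may differ):

import importlib

class LazyImport:
    """Defer importing a module until one of its attributes is actually used."""

    def __init__(self, module_name: str):
        self.module_name = module_name
        self._module = None

    def __getattr__(self, name):
        # only reached for attributes that should come from the wrapped module
        if self._module is None:
            self._module = importlib.import_module(self.module_name)
        return getattr(self._module, name)

awq = LazyImport("awq")  # the pip package "autoawq" is importable as "awq"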
2 changes: 2 additions & 0 deletions examples/language-modeling/main.py
@@ -320,6 +320,8 @@ def get_library_version(library_name):
model = model.to("cpu")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
if args.bits == 4:
autoround.save_quantized(f'{export_dir}-awq', format="auto_awq", inplace=inplace, model_path=args.model_name)

if not args.disable_eval and "fake" in deployment_device: ##support autogptq real eval later
excel_name = f"{output_dir}_result.xlsx"
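Once exported, the {export_dir}-awq directory is a regular AWQ checkpoint (its config.json carries quant_method "awq"), so it should load like any other AWQ model. A minimal inference sketch, assuming autoawq, accelerate, and a CUDA device are available; the path is illustrative and stands in for the directory written above:

from transformers import AutoModelForCausalLM, AutoTokenizer

quantized_dir = "./opt-125m-awq"  # e.g. the f"{export_dir}-awq" directory written above
model = AutoModelForCausalLM.from_pretrained(quantized_dir, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(quantized_dir)

inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))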
2 changes: 1 addition & 1 deletion examples/language-modeling/requirements.txt
@@ -15,4 +15,4 @@ auto-gptq
openpyxl
wandb
py-cpuinfo

autoawq