support autoawq format #115

Merged
48 commits merged into main from lyt/autoawq on Aug 5, 2024
Changes from 36 commits

Commits (48)
c7770af
add file
yintong-lu May 21, 2024
008899d
add file
yintong-lu May 21, 2024
49deb12
update
yintong-lu May 21, 2024
b1b729d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
9e6cf36
update
yintong-lu May 23, 2024
86b9b96
fix import error
yintong-lu May 23, 2024
10a91e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
4807994
minor change
yintong-lu May 23, 2024
5822543
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
991b5e9
update
yintong-lu May 23, 2024
4051f55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
f966812
fix pylint error
yintong-lu May 23, 2024
5a1188c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
2942bf9
Merge branch 'main' into lyt/autoawq
yintong-lu May 24, 2024
9c4f73b
minor fix
yintong-lu May 24, 2024
05d6e1d
fix pylint issue
yintong-lu May 24, 2024
72b48b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
40e8ce7
update import of awq
yintong-lu May 24, 2024
55b05e6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
07c0f18
fix import error
yintong-lu May 24, 2024
8184d84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 24, 2024
02687ac
modify code
yintong-lu May 27, 2024
b4b6fa3
Merge branch 'main' into lyt/autoawq
yintong-lu May 27, 2024
6d6b8b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2024
e91fb20
fix doc typo
yintong-lu May 27, 2024
a775b21
fix mixtral issue
yintong-lu May 27, 2024
6e76f05
Merge branch 'main' into lyt/autoawq
yintong-lu May 27, 2024
fc39b70
mv awq to autoround format, fix mixtral issue and follow autoround fo…
yintong-lu May 28, 2024
3480b8d
minor fix
yintong-lu May 28, 2024
406410c
fix conflicts
yintong-lu May 29, 2024
949a262
fix conflicts
yintong-lu Jun 11, 2024
81576c5
mv to autoround format
yintong-lu Jun 14, 2024
1267b5f
Merge branch 'main' into lyt/autoawq
yintong-lu Jun 25, 2024
a5ceb0b
minor fix
yintong-lu Jun 25, 2024
7d505a0
typo
yintong-lu Jun 25, 2024
56e60d7
remove comments
yintong-lu Jun 25, 2024
3385eb9
update comments
yintong-lu Jun 28, 2024
5ec30d4
move awq to autoround format evaluation
yintong-lu Jul 3, 2024
d5537f4
pylint error fixing
yintong-lu Jul 3, 2024
821a1b9
fix conflicts, and add awq format
yintong-lu Jul 24, 2024
7980a58
fix conflicts
yintong-lu Jul 24, 2024
5b271ef
Merge branch 'main' into lyt/autoawq
yintong-lu Jul 24, 2024
1f87cad
add ut
yintong-lu Jul 24, 2024
c888e12
add requirement
yintong-lu Jul 24, 2024
1bc88db
add coverage test waiver
yintong-lu Jul 24, 2024
bee724a
minor change
yintong-lu Jul 24, 2024
949c48a
refine code to decrease number of branches
yintong-lu Jul 25, 2024
7794386
add ut, fix minor issues
yintong-lu Jul 30, 2024
111 changes: 96 additions & 15 deletions auto_round/export/export_to_autoround/export.py
@@ -88,6 +88,13 @@ def dynamic_QuantLienar_for_packing(backend, bits, group_size):
        )
        return QuantLinear
    ##export all use triton, inference use exllamav2
    elif "awq" in backend:
        try:
            from awq.modules.linear import WQLinear_GEMM  # pylint: disable=E0401
            return WQLinear_GEMM
        except:
            logger.error("autoawq is required. Please install it to support auto_awq format.")
            return
    elif "autoround" in backend or "auto-round" in backend or "auto_round" in backend:
        from auto_round_extension.cuda.qliner_triton import QuantLinear
        return QuantLinear
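
For context, a minimal sketch of how this dispatch could be exercised, assuming autoawq is installed; the import path and the example arguments are illustrative, not taken from the PR:

# Illustrative sketch: any backend string containing "awq" routes to AutoAWQ's WQLinear_GEMM.
from auto_round.export.export_to_autoround.export import dynamic_QuantLienar_for_packing

QuantLinear = dynamic_QuantLienar_for_packing(backend="awq:gemm", bits=4, group_size=128)
print(QuantLinear)  # WQLinear_GEMM if autoawq is installed, otherwise None after the error log
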
@@ -103,11 +110,14 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="autoround:exl
    model = copy.deepcopy(model.to("cpu"))
    layer_names_in_block = get_layer_names_in_block(model)

    modules_to_not_convert = []
    weight_config = kwargs["weight_config"]
    for name in weight_config.keys():

        config = kwargs["weight_config"][name]
        if config["data_type"] != "int" and config["bits"] >= 16:
            if "awq" in backend:
                modules_to_not_convert.append(name)
            continue
        logger.info(f"packing {name}")

@@ -130,21 +140,40 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="autoround:exl
        out_features = layer.weight.shape[1]
        bias = layer.bias is not None and torch.any(layer.bias)

        if "awq" not in backend:
            new_layer = QuantLinear(  ##pylint: disable=E1123
                bits, group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype
            )

            new_layer.device = device
            set_module(model, name, new_layer)
            qlayer = new_layer
            scale = weight_config[name]["scale"]
            zero = weight_config[name]["zp"]
            # so far can only pack layer on CPU
            qlayer.to("cpu")
            ##force to float32 to be compatible with torch 2.0
            layer, scale, zero = layer.to("cpu"), scale.to("cpu"), zero.to("cpu").to(torch.float32)
            qlayer.pack(layer, scale, zero, None)
            qlayer.to(device)
        else:
            from awq.utils.utils import clear_memory  # pylint: disable=E0401
            scale, zp = weight_config[name]["scale"].to(torch.float32), weight_config[name]["zp"].to(torch.float32)
            scale = scale.t().contiguous()
            zp = zp.t().contiguous()
            if bits != 4:
                logger.error("AutoAWQ format only supports 4-bits quantization.")
            qlayer = QuantLinear.from_linear(
                linear=layer,
                w_bit=bits,
                group_size=group_size,
                init_only=False,
                scales=scale,
                zeros=zp,
            )
            qlayer.to(device)
            set_module(model, name, qlayer)
            clear_memory()
    quantization_config = kwargs["serialization_dict"]
    quantization_config["quant_method"] = "intel/auto-round"
    quantization_config["backend"] = backend
@@ -174,10 +203,22 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="autoround:exl
    quantization_config["extra_config"] = extra_config
    if hasattr(model, "config"):
        model.config.quantization_config = quantization_config
    if "awq" in backend:
        awq_quant_config = {
            "quant_method": "awq",
            "zero_point": not quantization_config["sym"],
            "group_size": quantization_config["group_size"],
            "bits": quantization_config["bits"],
            "version": "gemm",
            "modules_to_not_convert": None if not modules_to_not_convert else modules_to_not_convert,
        }
    tokenizer = kwargs["tokenizer"]
    if tokenizer is not None:
        tokenizer.save_pretrained(output_dir)
    if "awq" not in backend:
        save(model, output_dir)
    else:
        save_awq(model, output_dir, awq_quant_config=awq_quant_config)
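
For reference, with typical 4-bit, group-size-128 settings the dict above serializes into the exported config roughly as follows (the concrete values are illustrative, not taken from the diff):

awq_quant_config = {
    "quant_method": "awq",
    "zero_point": True,  # i.e. asymmetric quantization (sym == False)
    "group_size": 128,
    "bits": 4,
    "version": "gemm",
    "modules_to_not_convert": None,  # or e.g. ["lm_head"] for layers kept at >= 16 bits
}
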


def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_serialization: bool = True):
@@ -206,3 +247,43 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_serialization: bool = True):
    if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
        with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
            json.dump(model.config.quantization_config, f, indent=2)


def save_awq(
    model: nn.Module,
    save_dir: str,
    max_shard_size: str = "5GB",
    safe_serialization: bool = True,
    awq_quant_config: dict = {},
):
    """Save model state dict and configs.

    Args:
        model (`nn.Module`):
            Model to be saved. The model can be wrapped or unwrapped.
        save_dir (`str`):
            Directory to which to save. Will be created if it doesn't exist.
        max_shard_size (`str`, defaults to `"5GB"`):
            The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
            lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
            <Tip warning={true}>

            If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
            which will be bigger than `max_shard_size`.

            </Tip>
        safe_serialization (`bool`, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
        awq_quant_config (`dict`, defaults to `{}`):
            AutoAWQ-style quantization configuration to store in the exported config.
    """
    os.makedirs(save_dir, exist_ok=True)
    if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
        quantization_config = model.config.quantization_config
    else:
        quantization_config = awq_quant_config
    model.config.quantization_config = awq_quant_config
    model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
    config_file = "quantization_config.json"
    if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
        with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
            json.dump(quantization_config, f, indent=2)
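
Not part of the diff, but to close the loop: a checkpoint exported through save_awq should be loadable with the standard AutoAWQ entry point (a sketch; the path is hypothetical and autoawq must be installed):

from awq import AutoAWQForCausalLM  # pylint: disable=E0401
from transformers import AutoTokenizer

quant_path = "./model-autoawq-w4g128"  # hypothetical output_dir written by save_awq above
model = AutoAWQForCausalLM.from_quantized(quant_path)
tokenizer = AutoTokenizer.from_pretrained(quant_path)
# Recent transformers releases can also load quant_method="awq" checkpoints directly via
# AutoModelForCausalLM.from_pretrained(quant_path), provided autoawq is installed.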