
support to export gguf q4_0 and q4_1 format #393

Merged
merged 30 commits on Jan 8, 2025
Changes from 14 commits
Commits
30 commits
8355347
export gguf
n1ck-guo Dec 24, 2024
dd55003
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 24, 2024
f67219b
q4_0/1 port c++ to python
n1ck-guo Dec 24, 2024
611c4c1
Merge branch 'hengguo/gguf' of https://github.com/intel/auto-round in…
n1ck-guo Dec 24, 2024
ce1c48e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 24, 2024
7ab730b
change to llama.cpp stype and add uint8 store
n1ck-guo Dec 25, 2024
287b5af
abstract
n1ck-guo Dec 25, 2024
49d95a8
merge
n1ck-guo Dec 25, 2024
113532a
update
n1ck-guo Dec 26, 2024
ee66c47
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
d395c6b
fix
n1ck-guo Dec 26, 2024
8b13f1f
Merge branch 'hengguo/gguf' of https://github.com/intel/auto-round in…
n1ck-guo Dec 26, 2024
ce2c346
update
n1ck-guo Dec 30, 2024
8bceb3f
default sequence eval
n1ck-guo Dec 30, 2024
722a1d8
modify by comments
n1ck-guo Dec 30, 2024
8712170
update
n1ck-guo Dec 30, 2024
1aa979a
pylint
n1ck-guo Dec 30, 2024
515160d
clean
n1ck-guo Dec 30, 2024
a064c44
pylint
n1ck-guo Dec 30, 2024
fa2328d
fix
n1ck-guo Dec 30, 2024
7906284
update
n1ck-guo Dec 31, 2024
4261191
Merge branch 'main' into hengguo/gguf
n1ck-guo Dec 31, 2024
e525f97
add ut
n1ck-guo Dec 31, 2024
b0f96a0
add cuda ut
n1ck-guo Dec 31, 2024
c7ec3a5
add requirements
n1ck-guo Dec 31, 2024
79c5c5a
format
n1ck-guo Dec 31, 2024
2720287
code scane
n1ck-guo Dec 31, 2024
db15354
update
n1ck-guo Jan 7, 2025
24a68a9
merge main
n1ck-guo Jan 7, 2025
cb67c1a
update
n1ck-guo Jan 7, 2025
@@ -0,0 +1 @@
endianess
12 changes: 9 additions & 3 deletions auto_round/__main__.py
@@ -14,9 +14,15 @@
import sys

def run_eval():
from auto_round.script.llm import setup_eval_parser, eval
args = setup_eval_parser()
eval(args)
if "--non_sequence" in sys.argv:
sys.argv.remove("--non_sequence")
from auto_round.script.llm import setup_eval_parser, eval
args = setup_eval_parser()
eval(args)
else:
from auto_round.script.llm import setup_eval_parser, eval_sequence
args = setup_eval_parser()
eval_sequence(args)

def run():
if "--eval" in sys.argv:
5 changes: 5 additions & 0 deletions auto_round/export/__init__.py
@@ -48,3 +48,8 @@ def _save_quantized_as_autoawq(*args, **kwargs):
from auto_round.export.export_to_awq.export import save_quantized_as_autoawq

return save_quantized_as_autoawq(*args, **kwargs)

@register_format("gguf")
def _save_quantized_as_gguf(*args, **kwargs):
from auto_round.export.export_to_gguf.export import save_quantized_as_gguf
return save_quantized_as_gguf(*args, **kwargs)
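
For reference, the register_format decorator used above follows a plain registry pattern, with the gguf import deferred until that format is actually requested. A minimal sketch under the assumption that handlers live in a module-level dict; the dict name EXPORT_FORMATS below is hypothetical and not taken from this diff:

EXPORT_FORMATS = {}  # hypothetical name for the format -> handler registry

def register_format(name):
    def decorator(func):
        EXPORT_FORMATS[name] = func  # e.g. "gguf" -> _save_quantized_as_gguf
        return func
    return decorator

# An exporter can then be looked up and called by its format string, e.g.
# EXPORT_FORMATS["gguf"](output_dir, backend="gguf:q4_0", model=model, tokenizer=tokenizer)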
4,442 changes: 4,442 additions & 0 deletions auto_round/export/export_to_gguf/convert.py

Large diffs are not rendered by default.

78 changes: 78 additions & 0 deletions auto_round/export/export_to_gguf/export.py
@@ -0,0 +1,78 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import shutil
import torch
from .convert import Model
from auto_round.utils import logger
from pathlib import Path

import gguf # pylint: disable=E0401


FTYPE_MAP: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"q4_0": gguf.LlamaFileType.MOSTLY_Q4_0,
"q4_1": gguf.LlamaFileType.MOSTLY_Q4_1,
"q4_k": gguf.LlamaFileType.MOSTLY_Q4_K_S,
"auto": gguf.LlamaFileType.GUESSED,
}

def save_quantized_as_gguf(output_dir, backend="gguf:q4_0", **kwargs):
"""Export the model to gguf format."""

model = kwargs["model"]
tokenizer = kwargs.get("tokenizer", None)
config = model.config

tmp_work_dir = Path(os.path.join(output_dir, 'tmp_dir'))
if tokenizer is not None:
tokenizer.save_pretrained(tmp_work_dir)
config.save_pretrained(tmp_work_dir)

with torch.inference_mode():
hparams = Model.load_hparams(tmp_work_dir)
model_architecture = hparams["architectures"][0]
try:
model_class = Model.from_model_architecture(model_architecture)
except NotImplementedError:
logger.error(f"Model {model_architecture} is not supported")
sys.exit(1)
model_name = model.name_or_path.split('/')
if len(model_name[-1]) == 0:
model_name = model_name[-2]
else:
model_name = model_name[-1]

output_type = backend.split(":")[-1]
output_type = FTYPE_MAP.get(output_type.lower())

layer_config = kwargs.get("layer_config")

model_instance = model_class(model, dir_model=tmp_work_dir, ftype=output_type, fname_out=Path(output_dir),
layer_config=layer_config, is_big_endian=False, model_name=model_name,
split_max_tensors=False, split_max_size=0, dry_run=False,
small_first_shard=False)
model_instance.write()
logger.info(f"Model successfully exported to {model_instance.fname_out}")

shutil.rmtree(tmp_work_dir, ignore_errors=True)

return model
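
A minimal usage sketch for the exporter above, assuming a quantized Hugging Face model and its tokenizer are already in memory and that per-layer quantization parameters are available as layer_config (all three names stand for objects produced elsewhere by AutoRound):

from auto_round.export.export_to_gguf.export import save_quantized_as_gguf

# Converts the in-memory model to GGUF under ./gguf_out; the backend string
# selects the file type through FTYPE_MAP, and a temporary tokenizer/config
# directory is created and then removed during conversion.
save_quantized_as_gguf(
    "./gguf_out",
    backend="gguf:q4_0",
    model=model,                # assumed: quantized transformers model
    tokenizer=tokenizer,        # assumed: matching tokenizer
    layer_config=layer_config,  # assumed: per-layer scales/zero points, may be None
)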
100 changes: 100 additions & 0 deletions auto_round/export/export_to_gguf/quant.py
@@ -0,0 +1,100 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gguf # pylint: disable=E0401
import numpy as np

QK_K = 256
GGML_QUANT_SIZES = {
"bf16": (1, 2),
"q4_0": (32, 2 + 16),
"q4_1": (32, 2 + 2 + 16),
"q4_k": (256, 2 + 2 + QK_K // 2 + 12),
}

GGML_QUANT_BLOCK = {}

def register_block(name):
def register(cls):
GGML_QUANT_BLOCK[name] = cls
return cls
return register


def ggml_quant(data: np.array, ggml_type, scale = None, zp = None):
block_size, type_size = GGML_QUANT_SIZES[ggml_type]

data = data.astype(np.float32, copy=False)
shape = data.shape
n_blocks = data.size // block_size
blocks = data.reshape((n_blocks, block_size))

new_data = GGML_QUANT_BLOCK[ggml_type](blocks, scale, zp)
new_data = new_data.reshape(*shape[:-1], shape[-1] // block_size * type_size)
return new_data
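
As a size sanity check (a standalone sketch, not part of the patch): a q4_0 block covers 32 weights and serializes to 2 + 16 = 18 bytes, so a 4096-wide row shrinks to 4096 // 32 * 18 = 2304 bytes, matching the reshape above:

import numpy as np
from auto_round.export.export_to_gguf.quant import ggml_quant, GGML_QUANT_SIZES

block_size, type_size = GGML_QUANT_SIZES["q4_0"]              # (32, 18)
row = np.random.randn(1, 4096).astype(np.float32)
packed = ggml_quant(row, "q4_0")                              # scale derived from the data
assert packed.dtype == np.uint8
assert packed.shape == (1, 4096 // block_size * type_size)    # (1, 2304)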

@register_block("bf16")
def bf16_quant_block(blocks: np.array, scale = None, zp = None):
n = blocks.view(np.uint32)
# force nan to quiet
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
# round to nearest even
n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
return n.astype(np.uint16).view(np.uint8)
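
The 0x7fff + ((n >> 16) & 1) term is the usual round-to-nearest-even bias when truncating an fp32 bit pattern to bf16. A small standalone check; expected values assume a little-endian host:

import numpy as np
from auto_round.export.export_to_gguf.quant import bf16_quant_block

x = np.array([[1.0, -2.5, 3.1415927]], dtype=np.float32)
raw = bf16_quant_block(x)          # uint8 view of the bf16 payload
print(raw.view(np.uint16))         # [[16256 49184 16457]] i.e. 0x3F80, 0xC020, 0x4049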

@register_block("q4_0")
def q4_0_quant_block(blocks: np.array, scale = None, zp = None):
if scale is not None:
d = scale.reshape((-1,1))
else:
imax = abs(blocks).argmax(axis=-1, keepdims=True)
max = np.take_along_axis(blocks, imax, axis=-1)
d = max / -8
with np.errstate(divide="ignore"):
id = np.where(d == 0, 0, 1 / d)

qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)

n_blocks = blocks.shape[0]
block_size = GGML_QUANT_SIZES["q4_0"][0]
qs = qs.reshape((n_blocks, 2, block_size // 2))
qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))

d = d.astype(np.float16).view(np.uint8)

return np.concatenate([d, qs], axis=-1)
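
For reference, the packing above can be checked with a small dequantization sketch, assuming the standard llama.cpp q4_0 layout (a float16 scale followed by 16 bytes of packed nibbles, weights reconstructed as d * (q - 8)); it operates on the (n_blocks, 18) array produced here, before the final reshape in ggml_quant:

import numpy as np

def q4_0_dequant_block(raw: np.ndarray) -> np.ndarray:
    """Sketch of the inverse of q4_0_quant_block for a (n_blocks, 18) uint8 array."""
    d = raw[:, :2].copy().view(np.float16).astype(np.float32)   # (n_blocks, 1) scales
    qs = raw[:, 2:]                                             # (n_blocks, 16) nibbles
    lo = (qs & 0x0F).astype(np.int8) - 8                        # first 16 weights of the block
    hi = (qs >> 4).astype(np.int8) - 8                          # last 16 weights of the block
    q = np.concatenate([lo, hi], axis=-1).astype(np.float32)    # (n_blocks, 32)
    return d * q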


@register_block("q4_1")
def q4_1_quant_block(blocks: np.array, scale = None, zp = None):
if scale is not None:
d = scale.reshape((-1,1))
min = zp.reshape((-1,1))
else:
max = blocks.max(axis=-1, keepdims=True)
min = blocks.min(axis=-1, keepdims=True)
d = (max - min) / 15
with np.errstate(divide="ignore"):
id = np.where(d == 0, 0, 1 / d)

qs = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 15)

n_blocks = blocks.shape[0]
block_size = GGML_QUANT_SIZES["q4_1"][0]
qs = qs.reshape((n_blocks, 2, block_size // 2))
qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))

d = d.astype(np.float16).view(np.uint8)
m = min.astype(np.float16).view(np.uint8)
return np.concatenate([d, m, qs], axis=-1)
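
When AutoRound's own per-block parameters are passed in, the data-derived branch is skipped. A hedged sketch of calling ggml_quant with explicit parameters, assuming one value per 32-weight block (which is what the reshape((-1, 1)) calls above expect); note that for q4_1 the zp argument is used directly as the stored block minimum:

import numpy as np
from auto_round.export.export_to_gguf.quant import ggml_quant

weights = np.random.randn(64, 4096).astype(np.float32)
blocks = weights.reshape(-1, 32)                  # q4_1 uses 32-element blocks

zp = blocks.min(axis=-1)                          # stored as the block minimum "m"
scale = (blocks.max(axis=-1) - zp) / 15           # stored as the block scale "d"

packed = ggml_quant(weights, "q4_1", scale=scale, zp=zp)
# 32 weights -> 2 (d) + 2 (m) + 16 (nibbles) = 20 bytes per block.
assert packed.shape == (64, 4096 // 32 * 20)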
77 changes: 75 additions & 2 deletions auto_round/script/llm.py
@@ -257,8 +257,8 @@ def tune(args):
if args.format is None:
args.format = "auto_round"
supported_formats = ["auto_round", "auto_gptq", "auto_awq", "auto_round:auto_gptq", "auto_round:auto_awq",
"auto_gptq:marlin", "itrex", "iterx_xpu", "fake"]
formats = args.format.replace(' ', '').split(",")
"auto_gptq:marlin", "itrex", "iterx_xpu", "fake", "gguf:q4_0", "gguf:q4_1"]
formats = args.format.lower().replace(' ', '').split(",")
for format in formats:
if format not in supported_formats:
raise ValueError(f"{format} is not supported, we only support {supported_formats}")
@@ -312,6 +312,23 @@ def tune(args):
from auto_round.utils import detect_device, get_library_version, detect_device_count
from auto_round.utils import logger

if args.format in ["gguf:q4_0", "gguf:q4_1"]:
args.bits = 4
if args.act_bits <= 8:
logger.warning(f"{args.format} not support for activation quantization. Reset act_bits to 16.")
args.act_bits = 16
if args.group_size != 32:
logger.warning(f"{args.format} not support for group_size: {args.group_size}. "
"Reset group_size to 32.")
args.group_size = 32
if args.format.endswith("_0"):
args.asym = False
if args.format.endswith("_1"):
args.asym = True

if "gguf" in args.format:
args.disable_eval = True

model_name = args.model
if model_name[-1] == "/":
model_name = model_name[:-1]
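
The resets in the gguf branch above follow from the GGUF formats themselves: q4_0 and q4_1 both use 32-element blocks (see GGML_QUANT_SIZES in quant.py), q4_0 stores only a scale (symmetric) while q4_1 also stores a per-block minimum (asymmetric), and both are weight-only. A purely illustrative summary of the effective settings:

# Illustrative only: effective tuning settings implied by the gguf branch above.
GGUF_TUNE_SETTINGS = {
    "gguf:q4_0": {"bits": 4, "group_size": 32, "act_bits": 16, "asym": False},
    "gguf:q4_1": {"bits": 4, "group_size": 32, "act_bits": 16, "asym": True},
}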
@@ -553,3 +570,59 @@ def eval(args):

from lm_eval.utils import make_table # pylint: disable=E0401
print(make_table(res))

def eval_sequence(args):
import os
devices = args.device.replace(" ", "").split(',')
parallelism = False

if all(s.isdigit() for s in devices):
if "CUDA_VISIBLE_DEVICES" in os.environ:
current_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
current_visible_devices = current_visible_devices.split(',')
indices = [int(device) for device in devices]
try:
pick_device = [current_visible_devices[i] for i in indices]
except:
raise ValueError(
"Invalid '--device' value: It must be smaller than the number of available devices. "
"For example, with CUDA_VISIBLE_DEVICES=4,5, "
"--device 0,1 is valid, but --device 4,5 is not supported.")
visible_devices = ','.join(pick_device)
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
else:
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
args.device = ",".join(map(str, range(len(devices))))
devices = args.device.replace(" ", "").split(',')
if len(devices) > 1:
parallelism = True
device_str = None
elif args.device == "auto":
device_str = None
parallelism = True
else:
device_str = detect_device(args.device.replace(" ", ""))

from auto_round.eval.evaluation import simple_evaluate

model_args = f"pretrained={args.model},trust_remote_code={not args.disable_trust_remote_code}"
if parallelism:
model_args += ",parallelize=True"
if isinstance(args.tasks, str):
tasks = args.tasks.split(',')

from lm_eval.utils import make_table # pylint: disable=E0401
all_res = {}
for task in tasks:
res = simple_evaluate(
model="hf",
model_args=model_args,
tasks=task,
device=device_str,
batch_size=args.eval_bs)
if all_res == {}:
all_res = res
else:
for key in ['results', "versions", "n-shot", "higher_is_better"]:
all_res[key].update(res[key])
print(make_table(all_res))
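
The device handling in eval_sequence remaps the user-supplied indices onto whatever CUDA_VISIBLE_DEVICES already exposes. A standalone sketch of that path, using the example from the error message (CUDA_VISIBLE_DEVICES=4,5 with --device 0,1):

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"             # physical GPUs already exposed
requested = "0,1".split(",")                            # value of --device

visible = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
picked = [visible[int(i)] for i in requested]           # ["4", "5"]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(picked)   # "4,5"

# args.device is then rewritten to local indices, and with more than one
# device remaining the model is loaded with parallelize=True for lm-eval.
local_device = ",".join(map(str, range(len(requested))))   # "0,1"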
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -0,0 +1,2 @@
[tool.codespell]
ignore-words = ".azure-pipelines/scripts/codeScan/codespell/autoround_dict.txt"