This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

[LLM Runtime] Support 3bits & 4bits GPTQ for gpt-j (#100)
Zhenzhong1 authored Jan 31, 2024
1 parent bb80273 commit 4c90707
Showing 4 changed files with 355 additions and 81 deletions.
137 changes: 67 additions & 70 deletions neural_speed/convert/common.py
@@ -35,6 +35,28 @@
GGML_QK4_1_TYPE = 3
GGML_QJBLAS_TYPE = 19

# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
This also avoids mapping to whitespace/control characters that the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
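
The table is bijective, so token strings built with it can be mapped back to raw bytes by inverting it, which is exactly what the vocab-writing step in convert_quantized_gptj.py below relies on. A minimal standalone sketch, illustrative only and not part of this file:

    byte_encoder = bytes_to_unicode()
    byte_decoder = {v: k for k, v in byte_encoder.items()}
    # round-trip a small byte string through the printable-unicode alphabet
    token = "".join(byte_encoder[b] for b in "hello ".encode("utf-8"))
    assert bytes(byte_decoder[c] for c in token) == b"hello "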

def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q4_0 in ggml.c
assert tensor.shape[1] % GGML_QK4_0 == 0
@@ -175,25 +197,64 @@ def unpack_weight(qweight, scales, qzeros, q_config):
raise ValueError(f"Unsupported q_config without quant_method: {q_config}")
quant_method = q_config["quant_method"]
if quant_method == "gptq":
return unpack_gptq_weight(qweight, scales, qzeros, q_config)
qbits = q_config["bits"]
if qbits == 4:
return unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config)
elif qbits == 3:
return unpack_gptq_weight_3bits(qweight, scales, qzeros, q_config)

raise ValueError(f"Unsupported q_config[bits]: {qbits}")

if quant_method == "awq":
return unpack_awq_weight(qweight, scales, qzeros, q_config)
raise ValueError(f"Unsupported quant_method: {quant_method}")


def unpack_gptq_weight(qweight, scales, qzeros, q_config):
def unpack_gptq_weight_4bits(qweight, scales, qzeros, q_config):
group_size = q_config['group_size']
bits = q_config['bits']
wf = torch.tensor([[ 0, 4, 8, 12, 16, 20, 24, 28]], dtype=torch.int32)
s32_bits = 32

assert bits == 4
# An int32 packs eight 4-bit values. These are the bit offsets of each value.
wf = torch.tensor(list(range(0, s32_bits, bits)), dtype=torch.int32).unsqueeze(0)
zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)

zeros = zeros + 1
zeros = zeros.reshape(scales.shape)


weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)
torch.bitwise_and(weight,(2 ** bits) - 1, out=weight)

return weight, scales, zeros
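
For intuition, each int32 in qweight/qzeros packs eight 4-bit values at bit offsets 0, 4, ..., 28, and the shift-and-mask above recovers them. A minimal standalone sketch of the same unpacking, with illustrative values that are not part of this file:

    import torch

    values = torch.tensor([1, 7, 0, 15, 3, 8, 2, 5], dtype=torch.int32)
    packed = torch.zeros(1, dtype=torch.int32)
    for i, v in enumerate(values):
        packed |= v << (4 * i)                        # pack eight 4-bit values into one int32

    offsets = torch.arange(0, 32, 4, dtype=torch.int32)  # same role as wf above
    unpacked = torch.bitwise_and(packed >> offsets, (2 ** 4) - 1)
    assert torch.equal(unpacked, values)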

def unpack_gptq_weight_3bits(qweight, scales, qzeros, q_config):
print("unpack_gptq_weight_3bits... ", end='')
group_size = q_config['group_size']
bits = q_config['bits']
s32_bits = 32

assert bits == 3
# An int32 can only hold ten 3-bit values (30 of 32 bits used). These are the bit offsets of each value.
wf = torch.tensor([[ i for i in range(0, s32_bits - bits, bits)]], dtype=torch.int32)
zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
wf.unsqueeze(0)).to(torch.int16 if bits == 8 else torch.int8)
torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)

zeros = zeros + 1
zeros = zeros.reshape(zeros.shape[0], -1)
zeros = zeros[:,:scales.shape[1]]

weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
wf.unsqueeze(-1)).to(torch.int16 if bits == 8 else torch.int8)

weight = weight.reshape(-1, weight.shape[-1])
input_feature = group_size * scales.shape[0]
weight = weight[:input_feature,:]

torch.bitwise_and(weight,(2 ** bits) - 1, out=weight)

return weight, scales, zeros
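
The 3-bit path differs because 32 is not a multiple of 3: each int32 holds only ten 3-bit values (bits 0..29, with the top two bits unused), which is why wf has only ten offsets and why the unpacked rows are trimmed back to the expected group_size * scales.shape[0]. A small standalone check, illustrative only:

    import torch

    bits, s32_bits = 3, 32
    wf = torch.tensor([list(range(0, s32_bits - bits, bits))], dtype=torch.int32)
    # ten offsets 0, 3, ..., 27 -> ten 3-bit values per int32, two padding bits
    assert wf.shape[1] == s32_bits // bits == 10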
@@ -254,7 +315,7 @@ def load_quantized_model(model_path):
return model, config, config["quantization_config"]


def convert_fp32_tensor(src_name, dst_name, model, fout):
def convert_to_fp32_tensor(src_name, dst_name, model, fout):
v = model[src_name]
shape = v.shape
# print("Processing non-Q4 variable: " + src_name +
@@ -329,7 +390,6 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h
qweight = model[f"{src_name}.qweight"]

weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
# import pdb; pdb.set_trace()
# weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2])
# num_itr = g_idx.shape[0]//x.shape[-1]
if 'desc_act' in q_config and q_config['desc_act']:
@@ -351,66 +411,3 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h
weight.numpy().tofile(fout)

print(f"converting {dst_name} qauntized tensor to fp32 tensor")


def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
# unpack weight and repack into jblas format
import neural_speed.llama_cpp as cpp_model
qzeros = model[f"{src_name}.qzeros"]
zeros = qzeros_to_zeros(qzeros)
scales = model[f"{src_name}.scales"]
qweight = model[f"{src_name}.qweight"]

int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
int_weight = int_weight.view(-1,int_weight.shape[-1])

# permute_func for llama-like model
if permute_func:
int_weight = permute_func(int_weight.t(), n_head, n_head_kv).t().contiguous()
gptq_scales = permute_func(gptq_scales.t(), n_head, n_head_kv).t().contiguous()
gptq_zeros = permute_func(gptq_zeros.t(), n_head, n_head_kv).t().contiguous()

# shuffle weight in GPTQ when act order is on
if 'desc_act' in q_config and q_config['desc_act']:
g_idx = model[f"{src_name}.g_idx"]
int_weight2 = int_weight.clone()
group_size=q_config['group_size']
group_dict = {}
for i in range(len(g_idx)):
group_idx = g_idx[i].item()
if group_idx not in group_dict:
target_idx = group_idx * group_size
group_dict[group_idx] = 0
else:
group_dict[group_idx] = group_dict[group_idx] + 1
target_idx = group_idx * group_size + group_dict[group_idx]
int_weight2[target_idx] = int_weight[i]
int_weight = int_weight2

shape = int_weight.shape
write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)

if q_config['bits'] == 4:
int_weight = (int_weight - 8) * 16
gptq_scales = gptq_scales / 16
gptq_zeros = (gptq_zeros - 8) * 16
dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
int_weight = np.ascontiguousarray(int_weight.numpy())
gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
if q_config['sym']:
gptq_zeros = np.empty(0, dtype=np.int8)
else:
gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
if 'desc_act' in q_config and q_config['desc_act']:
g_idx = np.ascontiguousarray(g_idx.numpy())
else:
g_idx = np.empty(0, dtype=np.int32)

# pack int weight in bestla format
byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst,
weight_dtype="int4" if q_config['bits'] == 4 else "int8",
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} qauntized tensor to bestla q4 block")
210 changes: 210 additions & 0 deletions neural_speed/convert/convert_quantized_gptj.py
@@ -0,0 +1,210 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import argparse
from common import *
from tqdm import tqdm
from transformers import AutoTokenizer


def permute_func(weights, n_head: int, n_head_kv: int):
if n_head_kv is not None and n_head != n_head_kv:
n_head //= n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2,
*weights.shape[1:]).swapaxes(1, 2).reshape(weights.shape))
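
permute_func applies the llama-style query/key row permutation; the GPT-J conversion path below does not appear to call it. On a toy weight it interleaves the two halves of each head, as in this standalone sketch (illustrative only):

    import numpy as np

    w = np.arange(8).reshape(8, 1)                 # toy [out_features, in_features] weight
    permuted = permute_func(w, n_head=2, n_head_kv=2)
    # row order becomes [0, 2, 1, 3, 4, 6, 5, 7]: within each head the two halves interleave
    assert permuted[:, 0].tolist() == [0, 2, 1, 3, 4, 6, 5, 7]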

def convert_to_qx_bestla_tensor(src_name, dst_name, model, fout, q_config):
# unpack weight and repack into 3bits / 4bits BestLA format
import neural_speed.llama_cpp as cpp_model
if ".weight" in src_name:
src_name = src_name.replace(".weight", "")
qzeros = model[f"{src_name}.qzeros"]
zeros = qzeros_to_zeros(qzeros)
scales = model[f"{src_name}.scales"]
qweight = model[f"{src_name}.qweight"]

int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
int_weight = int_weight.view(-1,int_weight.shape[-1])

# shuffle weight in GPTQ when act order is on
if 'desc_act' in q_config and q_config['desc_act']:
g_idx = model[f"{src_name}.g_idx"]
int_weight2 = int_weight.clone()
group_size=q_config['group_size']
group_dict = {}
for i in range(len(g_idx)):
group_idx = g_idx[i].item()
if group_idx not in group_dict:
target_idx = group_idx * group_size
group_dict[group_idx] = 0
else:
group_dict[group_idx] = group_dict[group_idx] + 1
target_idx = group_idx * group_size + group_dict[group_idx]
int_weight2[target_idx] = int_weight[i]
int_weight = int_weight2

shape = int_weight.shape
write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)

# INC stores the signed int4 value as u4 (range 0~15, i.e. with an offset added),
# while BesTLA expects s4_clip ((-8..7) * 16), so we subtract the offset and then multiply by 16.
# Int3 works the same way, but with offset = 4 and multiplier 32.
weight_dtype = "int8"
if q_config['bits'] == 4:
int_weight = (int_weight - 8) * 16
gptq_scales = gptq_scales / 16
gptq_zeros = (gptq_zeros - 8) * 16
weight_dtype = "int4"
elif q_config['bits'] == 3:
int_weight = (int_weight - 4) * 32
gptq_scales = gptq_scales / 32
gptq_zeros = (gptq_zeros - 4) * 32
weight_dtype = "int3"

dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
int_weight = np.ascontiguousarray(int_weight.numpy())
gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
if q_config['sym']:
gptq_zeros = np.empty(0, dtype=np.int8)
else:
gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
if 'desc_act' in q_config and q_config['desc_act']:
g_idx = np.ascontiguousarray(g_idx.numpy())
else:
g_idx = np.empty(0, dtype=np.int32)

# repack int weight in BesTLA format
byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst,
weight_dtype=weight_dtype,
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} qauntized tensor to bestla q4 block")


def main(args_in: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)

out_path = args.outfile.as_posix()
model_path = args.model.as_posix()

model, config, quantize_config = load_quantized_model(model_path)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
fout = open(out_path, "wb")

# 1. write hparams
hparams = config
n_layer = hparams["n_layer"]
fout.write(b"ggjt"[::-1]) #0x67676d6c)) # magic: ggml in hex
values = [
1, # file version
hparams["vocab_size"],
hparams["n_embd"],
hparams["n_embd"] // hparams["n_head"],
hparams["n_head"],
hparams.get("n_head_kv", 0), # multi-query attention
hparams["n_layer"],
hparams["rotary_dim"],
0
]
fout.write(struct.pack("i" * len(values), *values))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", 0))
fout.write(struct.pack("f", 0))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)

fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms norm eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))

# 2. vocab
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

encoder = tokenizer.vocab
# Add added_tokens (special tokens) to the encoder
encoder_added = tokenizer.get_added_vocab()

for i, key in enumerate(sorted(encoder, key=encoder.get)):
# for key in encoder:
text = bytearray([byte_decoder[c] for c in key])
fout.write(struct.pack("i", len(text)))
fout.write(text)
if key not in encoder_added:
fout.write(struct.pack("f", 0.0 - i))
else:
fout.write(struct.pack("f", -10000))

# 3. write tensors
list_vars = model

convert_to_fp32_tensor("transformer.wte.weight", "transformer.wte.weight", list_vars, fout)
convert_to_fp32_tensor("transformer.ln_f.weight", "transformer.ln_f.weight", list_vars, fout)
convert_to_fp32_tensor("transformer.ln_f.bias", "transformer.ln_f.bias", list_vars, fout)
convert_to_fp32_tensor("lm_head.bias", "lm_head.bias", list_vars, fout)
convert_to_fp32_tensor("lm_head.weight", "lm_head.weight", list_vars, fout)

for i in tqdm(range(n_layer), desc="Processing layers"):
convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.q_proj.weight",
f"transformer.h.{i}.attn.q_proj.weight", list_vars, fout, quantize_config)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.k_proj.weight",
f"transformer.h.{i}.attn.k_proj.weight", list_vars, fout, quantize_config)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.v_proj.weight",
f"transformer.h.{i}.attn.v_proj.weight", list_vars, fout, quantize_config)

convert_to_qx_bestla_tensor(f"transformer.h.{i}.attn.out_proj.weight",
f"transformer.h.{i}.attn.out_proj.weight", list_vars, fout, quantize_config)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.fc_in.weight",
f"transformer.h.{i}.mlp.fc_in.weight", list_vars, fout, quantize_config)
convert_to_qx_bestla_tensor(f"transformer.h.{i}.mlp.fc_out.weight",
f"transformer.h.{i}.mlp.fc_out.weight", list_vars, fout, quantize_config)

convert_to_fp32_tensor(f"transformer.h.{i}.mlp.fc_in.bias",
f"transformer.h.{i}.mlp.fc_in.bias", list_vars, fout)
convert_to_fp32_tensor(f"transformer.h.{i}.mlp.fc_out.bias",
f"transformer.h.{i}.mlp.fc_out.bias", list_vars, fout)

convert_to_fp32_tensor(f"transformer.h.{i}.ln_1.weight",
f"transformer.h.{i}.ln_1.weight", list_vars, fout)
convert_to_fp32_tensor(f"transformer.h.{i}.ln_1.bias",
f"transformer.h.{i}.ln_1.bias", list_vars, fout)


fout.close()
print(f"Success! saved as {out_path}")


if __name__ == '__main__':
main()
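
Assuming a GPTQ-quantized GPT-J checkpoint directory (the paths below are placeholders), a typical invocation of this converter would look like:

    python convert_quantized_gptj.py --outfile gptj-gptq.bin /path/to/gptq-gpt-j-model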