
Commit 94e74d7

[Model Enabling] Support ChatGLM3 (#182)

1 parent 20fd168

15 files changed: +554, -36 lines

docs/prompt_template.md

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# Prompt template

This document shows some examples of how to correctly use prompt templates in Neural Speed and [ITREX](https://github.com/intel/intel-extension-for-transformers).

For a base model (pre-trained without SFT), the prompt can be encoded directly into token ids without adding any special prefix or suffix tokens. A chat model, however, needs a prompt template to produce correct and human-readable output, because these models are usually trained with specific prompt templates.
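
As a minimal sketch of the base-model case, the prompt is tokenized as plain text and fed to the model directly. The model name and converted-weight path below are placeholders; only the APIs already used in the chat examples below are assumed.

```python
from transformers import AutoTokenizer
from neural_speed import Model

# Hypothetical base (non-chat) model; no chat template is applied.
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Plain encoding: no special prefix or suffix tokens are added by hand.
inputs = tokenizer("Once upon a time", return_tensors="pt").input_ids

model = Model()
model.init_from_bin("llama", "llama-2-7b-q4.bin")  # model_type and converted-weight path are placeholders
outputs = model.generate(inputs, max_new_tokens=32, do_sample=True)
print(tokenizer.decode(outputs[0]))
```
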
## Chat with ChatGLM3:
```python
from transformers import AutoTokenizer, TextStreamer
from neural_speed import Model

prompt = "你好"
tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
inputs = tokenizer.build_chat_input(prompt)['input_ids']
model = Model()
model.init_from_bin(args.model_name, gguf_path)
outputs = model.generate(inputs, max_new_tokens=300, do_sample=True)
words = tokenizer.decode(outputs[0])
```

## Chat with LLaMA2:

```python
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

# Please change to a local path to the model; llama2 does not currently support online conversion.
model_name = "meta-llama/Llama-2-7b-chat-hf"
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    prompt = input("> ").strip()
    if prompt == "quit":
        break
    b_prompt = "[INST]{}[/INST]".format(prompt)  # prompt template for llama2
    inputs = tokenizer(b_prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True)
```

## Chat with ChatGLM2:
```python
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "THUDM/chatglm2-6b"  # or local path to model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    prompt = input("> ").strip()
    if prompt == "quit":
        break
    prompt = tokenizer.build_prompt(prompt)  # prompt template for chatglm2
    inputs = tokenizer([prompt], return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True, n_keep=2)
```

## Chat with Qwen:
```python
from transformers import AutoTokenizer, TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig

model_name = "Qwen/Qwen-7B-Chat"  # or local path to model
woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
streamer = TextStreamer(tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

while True:
    prompt = input("> ").strip()
    if prompt == "quit":
        break
    prompt = "\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n".format(prompt)  # prompt template for qwen
    inputs = tokenizer([prompt], return_tensors="pt").input_ids
    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True)
```

docs/supported_models.md

Lines changed: 2 additions & 1 deletion

@@ -219,7 +219,8 @@ Neural Speed supports the following models:
     </tr>
     <tr>
       <td><a href="https://huggingface.co/THUDM/chatglm-6b" target="_blank" rel="noopener noreferrer">ChatGLM-6B</a>,
-      <a href="https://huggingface.co/THUDM/chatglm2-6b" target="_blank" rel="noopener noreferrer">ChatGLM2-6B</a></td>
+      <a href="https://huggingface.co/THUDM/chatglm2-6b" target="_blank" rel="noopener noreferrer">ChatGLM2-6B</a>,
+      <a href="https://huggingface.co/THUDM/chatglm3-6b" target="_blank" rel="noopener noreferrer">ChatGLM3-6B</a></td>
       <td>✅</td>
       <td> </td>
       <td> </td>

neural_speed/__init__.py

Lines changed: 8 additions & 2 deletions

@@ -24,6 +24,7 @@
 
 
 class Model:
+
     def __init__(self):
         self.module = None
         self.model = None
@@ -55,7 +56,7 @@ def __import_package(self, model_type):
             import neural_speed.bloom_cpp as cpp_model
         elif model_type == "chatglm":
             import neural_speed.chatglm_cpp as cpp_model
-        elif model_type == "chatglm2":
+        elif model_type == "chatglm2" or model_type == "chatglm3":
             import neural_speed.chatglm2_cpp as cpp_model
         elif model_type == "baichuan":
             import neural_speed.baichuan_cpp as cpp_model
@@ -85,6 +86,11 @@ def get_model_type(model_config):
         if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
             model_type = "chatglm2"
 
+        # For ChatGLM3
+        if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
+            # due to the same model architecture.
+            model_type = "chatglm2"
+
         # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
         if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
             model_type = "falcon"
@@ -200,7 +206,7 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs):
 
         def get_max_seq_length():
             config = self.config.to_dict()
-            # chatglm2, bloom
+            # chatglm2, bloom, chatglm3
            if 'seq_length' in config:
                return config['seq_length']
            # qwen2, llama-2, llama, dolly, gptneox, qwen, qwen1.5, opt, phi
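
For context, a minimal sketch of how the remapping added above routes ChatGLM3 onto the chatglm2 backend. This is not the repository code: the function is a simplified stand-in, and SimpleNamespace replaces the transformers config object.

```python
from types import SimpleNamespace

def resolve_model_type(model_config):
    # Simplified sketch of the remapping in Model.get_model_type.
    model_type = model_config.model_type
    if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
        model_type = "chatglm2"
    # ChatGLM3 reuses the chatglm2 backend because the architecture is the same.
    if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
        model_type = "chatglm2"
    return model_type

cfg = SimpleNamespace(model_type="chatglm", _name_or_path="THUDM/chatglm3-6b")
print(resolve_model_type(cfg))  # -> "chatglm2", so neural_speed.chatglm2_cpp is imported
```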

neural_speed/application/CMakeLists.txt

Lines changed: 3 additions & 1 deletion

@@ -65,6 +65,7 @@ compile_quant(quant_bloom quant_model.cpp bloom bloom)
 
 compile_quant(quant_chatglm quant_model.cpp chatglm chatglm)
 compile_quant(quant_chatglm2 quant_model.cpp chatglm2 chatglm2)
+compile_quant(quant_chatglm3 quant_model.cpp chatglm2 chatglm2)
 compile_quant(quant_baichuan quant_model.cpp baichuan baichuan)
 compile_quant(quant_mistral quant_model.cpp mistral llama)
 compile_quant(quant_mixtral quant_model.cpp mixtral llama)
@@ -97,7 +98,7 @@ set(mymap_phi 16)
 set(mymap_stablelm 17)
 set(mymap_whisper 18)
 set(mymap_mixtral 19)
-
+set(mymap_chatglm3 20)
 
 
 function(compile_run TARGET MAIN_CPP MAIN_PY MODEL_NAME MODEL_LIB)
@@ -128,6 +129,7 @@ compile_run(run_starcoder main_run.cpp main_pybind.cpp starcoder starcoder)
 compile_run(run_opt main_run.cpp main_pybind.cpp opt opt)
 compile_run(run_bloom main_run.cpp main_pybind.cpp bloom bloom)
 compile_run(run_chatglm2 main_run.cpp main_pybind.cpp chatglm2 chatglm2)
+compile_run(run_chatglm3 main_run.cpp main_pybind.cpp chatglm3 chatglm3)
 compile_run(run_chatglm main_run.cpp main_pybind.cpp chatglm chatglm)
 compile_run(run_baichuan main_run.cpp main_pybind.cpp baichuan baichuan)
 compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama)

neural_speed/application/main_pybind.cpp

Lines changed: 4 additions & 0 deletions

@@ -921,6 +921,10 @@ PYBIND11_MODULE(whisper_cpp, m)
 
 PYBIND11_MODULE(mixtral_cpp, m)
 
+#elif MODEL_NAME_ID == 20
+
+PYBIND11_MODULE(chatglm3_cpp, m)
+
 #endif
 {
   m.doc() = "cpp model python binding";

neural_speed/application/main_run.cpp

Lines changed: 3 additions & 2 deletions

@@ -240,7 +240,8 @@ int main(int argc, char** argv) { // NOLINT
    std::string prompt = build_prompt_glm2(prompts);
    embd_inp = ::model_tokenize(ctx, prompt, false);
    embd_inp.insert(embd_inp.begin(), {64790, 64792});  // special prefix
-  } else if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_BAICHUAN) {
+  } else if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_BAICHUAN ||
+             params.model_arch == MODEL_CHATGLM3) {
    for (auto& i : params.ids) {
      embd_inp.emplace_back(i);
    }
@@ -646,7 +647,7 @@ int main(int argc, char** argv) { // NOLINT
 
    // display text
    if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_CHATGLM2 ||
-       params.model_arch == MODEL_BAICHUAN) {
+       params.model_arch == MODEL_BAICHUAN || params.model_arch == MODEL_CHATGLM3) {
      static bool is_prompt = true;
      if (input_echo) {
        if (is_prompt == true) {

neural_speed/convert/common.py

Lines changed: 49 additions & 0 deletions

@@ -90,7 +90,56 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
     return tensor
 
 
+def quantize_q8_0(tensor: torch.Tensor) -> torch.Tensor:
+    # equivalent to ggml_quantize_q8_0 in ggml.c
+    assert tensor.shape[1] % GGML_QK8_0 == 0
+    tensor = tensor.view(-1, GGML_QK8_0)
+    scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
+    tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
+    # add scale into each block
+    tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
+    return tensor
+
+
+def quantize_q5_0(tensor: torch.Tensor) -> torch.Tensor:
+    # equivalent to ggml_quantize_q5_0 in ggml.c
+    assert tensor.shape[1] % GGML_QK5_0 == 0
+    tensor = tensor.view(-1, GGML_QK5_0)
+    abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
+    max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
+    scale = max_values / -16
+    tensor = (tensor / scale + 16).round().clamp(min=0, max=31).char()
+    qs = (tensor[:, :16] & 0x0F) | (tensor[:, 16:] << 4)
+    qh = torch.zeros(tensor.shape[:-1], dtype=torch.int32)
+    for i in range(32):
+        qh |= ((tensor[:, i] & 0x10) >> 4).int() << i
+
+    # add scale into each block
+    tensor = torch.cat((scale.half().view(torch.int8), qh[..., None].view(torch.int8), qs), dim=-1)
+    return tensor
+
+
+def quantize_q5_1(tensor: torch.Tensor) -> torch.Tensor:
+    # equivalent to ggml_quantize_q5_1 in ggml.c
+    assert tensor.shape[1] % GGML_QK5_1 == 0
+    tensor = tensor.view(-1, GGML_QK5_1)
+    min_vals = tensor.min(dim=-1, keepdim=True).values
+    max_vals = tensor.max(dim=-1, keepdim=True).values
+    scale = (max_vals - min_vals) / ((1 << 5) - 1)
+    tensor = ((tensor - min_vals) / scale).round().clamp(min=0, max=31).char()
+    qs = (tensor[:, :16] & 0x0F) | (tensor[:, 16:] << 4)
+    qh = torch.zeros(tensor.shape[:-1], dtype=torch.int32)
+    for i in range(32):
+        qh |= ((tensor[:, i] & 0x10) >> 4).int() << i
+
+    # add scale & min into each block
+    tensor = torch.cat(
+        (scale.half().view(torch.int8), min_vals.half().view(torch.int8), qh[..., None].view(torch.int8), qs), dim=-1)
+    return tensor
+
+
 class SentencePieceVocab:
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
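
A small usage sketch of the newly added quantize_q8_0 helper. The import path is inferred from this file's location, and the 32-value QK8_0 block size follows ggml's q8_0 format; both are assumptions here.

```python
import torch

from neural_speed.convert.common import quantize_q8_0  # path assumed from the file above

# Four rows, each a single 32-value block (GGML_QK8_0 is 32 in ggml).
weights = torch.randn(4, 32)
packed = quantize_q8_0(weights)

# Each block stores an fp16 scale (2 bytes, viewed as int8) followed by 32 int8 quants.
print(packed.shape)  # torch.Size([4, 34])
print(packed.dtype)  # torch.int8
```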
