This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 750b356

Enable Qwen1-5 (#146)
1 parent 6c36f54 commit 750b356

6 files changed (+120 additions, -34 deletions)

docs/supported_models.md (13 additions, 2 deletions)

```diff
@@ -215,7 +215,9 @@ Neural Speed supports the following models:
     </tr>
     <tr>
       <td><a href="https://huggingface.co/Qwen/Qwen-7B-Chat" target="_blank" rel="noopener noreferrer">Qwen-7B</a>,
-      <a href="https://huggingface.co/Qwen/Qwen-14B-Chat" target="_blank" rel="noopener noreferrer">Qwen-14B</a></td>
+      <a href="https://huggingface.co/Qwen/Qwen-14B-Chat" target="_blank" rel="noopener noreferrer">Qwen-14B</a>,
+      <a href="https://huggingface.co/Qwen/Qwen1.5-7B-Chat" target="_blank" rel="noopener noreferrer">Qwen1.5-7B</a>,
+      <a href="https://huggingface.co/Qwen/Qwen1.5-0.5B" target="_blank" rel="noopener noreferrer">Qwen1.5-0.5B</a></td>
       <td>✅</td>
       <td> </td>
       <td> </td>
@@ -358,6 +360,14 @@ Neural Speed supports the following models:
       <td>✅</td>
       <td></td>
     </tr>
+    <tr>
+      <td><a href="https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF" target="_blank" rel="noopener noreferrer">TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF</a></td>
+      <td>✅</td>
+      <td>✅</td>
+      <td>✅</td>
+      <td>✅</td>
+      <td></td>
+    </tr>
     <tr>
       <td><a href="https://huggingface.co/TheBloke/SOLAR-10.7B-Instruct-v1.0-GGUF" target="_blank" rel="noopener noreferrer">TheBloke/SOLAR-10.7B-Instruct-v1.0-GGUF</td>
       <td>✅</td>
@@ -410,7 +420,8 @@ Neural Speed supports the following models:
       <td>✅</td>
     </tr>
     <tr>
-      <td><a href="https://huggingface.co/Qwen/Qwen-7B-Chat" target="_blank" rel="noopener noreferrer">Qwen-7B-Chat</td>
+      <td><a href="https://huggingface.co/Qwen/Qwen-7B-Chat" target="_blank" rel="noopener noreferrer">Qwen-7B-Chat</a>,
+      <a href="https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GGUF" target="_blank" rel="noopener noreferrer">Qwen1.5-7B-Chat-GGUF</a></td>
       <td>✅</td>
       <td>✅</td>
       <td>✅</td>
```

neural_speed/convert/convert_qwen.py (8 additions, 5 deletions)

```diff
@@ -100,9 +100,11 @@ def main(args_in: Optional[List[str]] = None) -> None:
     fout.write(struct.pack("i", hparams["num_attention_heads"]))
     fout.write(struct.pack("i", 0))  # multi-query attention
     fout.write(struct.pack("i", hparams["num_hidden_layers"]))
-    fout.write(struct.pack("i", hparams["kv_channels"]))
+    fout.write(struct.pack("i", hparams["kv_channels"] if "kv_channels" in hparams
+                            else int(hparams["hidden_size"]/hparams["num_attention_heads"])))
     fout.write(struct.pack("i", ftype))
-    fout.write(struct.pack("i", hparams["seq_length"]))
+    fout.write(struct.pack("i", hparams["seq_length"] if "seq_length" in hparams
+                            else hparams["max_position_embeddings"]))
     fout.write(struct.pack("f", 0.0))
     fout.write(struct.pack("f", 0.0))
     fout.write(struct.pack("i", 0))
@@ -121,9 +123,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
     fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
     fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
     fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
-
-    fout.write(struct.pack("i", tokenizer.special_tokens['<|endoftext|>']))
-    fout.write(struct.pack("i", tokenizer.special_tokens['<|endoftext|>']))
+    fout.write(struct.pack("i", hparams["bos_token_id"] if hparams["bos_token_id"]
+                            else tokenizer.special_tokens['<|endoftext|>']))
+    fout.write(struct.pack("i", hparams["eos_token_id"] if hparams["eos_token_id"]
+                            else tokenizer.special_tokens['<|endoftext|>']))
     fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
     fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
```
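These converter changes make the header writer tolerant of both config layouts: Qwen1 checkpoints expose `kv_channels` and `seq_length` directly, while Qwen1.5 configs instead carry `hidden_size`, `num_attention_heads`, and `max_position_embeddings`, plus explicit `bos_token_id`/`eos_token_id`. A minimal sketch of the same fallback pattern follows; the two example dicts are hypothetical stand-ins for `config.json` contents, not values read from any real checkpoint.

```python
# Minimal sketch of the hparams fallbacks used by convert_qwen.py above.
# The example dicts are illustrative stand-ins for config.json, not real values.

def resolve_header_fields(hparams: dict) -> dict:
    """Derive the head width and context length the converter packs into the header."""
    kv_channels = hparams.get(
        "kv_channels", hparams["hidden_size"] // hparams["num_attention_heads"])
    seq_length = hparams.get("seq_length", hparams["max_position_embeddings"])
    return {"kv_channels": kv_channels, "seq_length": seq_length}

qwen1_like = {"kv_channels": 128, "seq_length": 8192,
              "hidden_size": 4096, "num_attention_heads": 32}
qwen1_5_like = {"hidden_size": 4096, "num_attention_heads": 32,
                "max_position_embeddings": 32768}

print(resolve_header_fields(qwen1_like))    # uses the explicit Qwen1 fields
print(resolve_header_fields(qwen1_5_like))  # falls back to derived/renamed fields
```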

neural_speed/models/model_utils/gguf.h (7 additions, 8 deletions)

```diff
@@ -231,18 +231,17 @@ enum llm_arch {
   LLM_ARCH_CHATGLM,
   LLM_ARCH_CHATGLM2,
   LLM_ARCH_PHI,
+  LLM_ARCH_QWEN2,
   LLM_ARCH_UNKNOWN,
 };
 
 static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
-    {LLM_ARCH_LLAMA, "llama"},     {LLM_ARCH_FALCON, "falcon"},
-    {LLM_ARCH_GPT2, "gpt2"},       {LLM_ARCH_GPTJ, "gptj"},
-    {LLM_ARCH_GPTNEOX, "gptneox"}, {LLM_ARCH_MPT, "mpt"},
-    {LLM_ARCH_BAICHUAN, "baichuan"}, {LLM_ARCH_STARCODER, "starcoder"},
-    {LLM_ARCH_PERSIMMON, "persimmon"}, {LLM_ARCH_REFACT, "refact"},
-    {LLM_ARCH_BLOOM, "bloom"}, {LLM_ARCH_STABLELM, "stablelm"},
-    {LLM_ARCH_QWEN, "qwen"}, {LLM_ARCH_CHATGLM, "chatglm"},
-    {LLM_ARCH_CHATGLM2, "chatglm2"}, {LLM_ARCH_PHI, "phi"}};
+    {LLM_ARCH_LLAMA, "llama"},       {LLM_ARCH_FALCON, "falcon"},       {LLM_ARCH_GPT2, "gpt2"},
+    {LLM_ARCH_GPTJ, "gptj"},         {LLM_ARCH_GPTNEOX, "gptneox"},     {LLM_ARCH_MPT, "mpt"},
+    {LLM_ARCH_BAICHUAN, "baichuan"}, {LLM_ARCH_STARCODER, "starcoder"}, {LLM_ARCH_PERSIMMON, "persimmon"},
+    {LLM_ARCH_REFACT, "refact"},     {LLM_ARCH_BLOOM, "bloom"},         {LLM_ARCH_STABLELM, "stablelm"},
+    {LLM_ARCH_QWEN, "qwen"},         {LLM_ARCH_CHATGLM, "chatglm"},     {LLM_ARCH_CHATGLM2, "chatglm2"},
+    {LLM_ARCH_PHI, "phi"},           {LLM_ARCH_QWEN2, "qwen2"}};
 
 struct gguf_tensor_info {
   struct gguf_str name;
```

neural_speed/models/qwen/qwen.cpp (37 additions, 11 deletions)

```diff
@@ -102,6 +102,12 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
   const int n_vocab = hparams.n_vocab;
   const int n_rot = hparams.n_rot;
   const int head_dim = n_embd / n_head;
+  int qwen_version = 0;
+  if (hparams.max_seq_len == 8192) {
+    qwen_version = 1;
+  } else {
+    qwen_version = 2;
+  }
 
   auto& mem_per_token = lctx.mem_per_token;
   auto& buf_compute = lctx.buf_compute;
@@ -164,20 +170,36 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
     }
 
     // compute QKV
-    {
+    struct ne_tensor* Qcur;
+    struct ne_tensor* Kcur;
+    struct ne_tensor* Vcur;
+
+    if (qwen_version == 1) {
       cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
 
       cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[1], cur), cur);
+      size_t fused_qkv_row_nb = (3 * n_embd) * sizeof(float);
+      Qcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
+                                      0 * sizeof(float) * n_embd));
+      // head_dim, n_head, N --> head_dim, N, n_head
+      Kcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
+                                      1 * sizeof(float) * n_embd));
+      // head_dim, n_head, N --> N, head_dim, n_head
+      Vcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
+                                      2 * sizeof(float) * n_embd));
+    } else {
+      Qcur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
+      Qcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[1], Qcur), Qcur);
+      Qcur = ne_reshape_3d(ctx0, Qcur, head_dim, n_head, N);
+
+      Kcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
+      Kcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[3], Kcur), Kcur);
+      Kcur = ne_reshape_3d(ctx0, Kcur, head_dim, n_head, N);
+
+      Vcur = ne_mul_mat(ctx0, model.layers[il].attn[4], cur);
+      Vcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[5], Vcur), Vcur);
+      Vcur = ne_reshape_3d(ctx0, Vcur, head_dim, n_head, N);
     }
-    size_t fused_qkv_row_nb = (3 * n_embd) * sizeof(float);
-    struct ne_tensor* Qcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float),
-                                                      fused_qkv_row_nb, 0 * sizeof(float) * n_embd));
-    // head_dim, n_head, N --> head_dim, N, n_head
-    struct ne_tensor* Kcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float),
-                                                      fused_qkv_row_nb, 1 * sizeof(float) * n_embd));
-    // head_dim, n_head, N --> N, head_dim, n_head
-    struct ne_tensor* Vcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float),
-                                                      fused_qkv_row_nb, 2 * sizeof(float) * n_embd));
 
     // using mode = 2 for GPT-NeoX mode
     Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
@@ -300,7 +322,11 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
       cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
     }
     // projection
-    { cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur); }
+    if (qwen_version == 1) {
+      cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
+    } else {
+      cur = ne_mul_mat(ctx0, model.layers[il].attn[6], cur);
+    }
   }
   lctx.use_buf(ctx0, 1);
 
```
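In the eval path, the two generations are told apart by `hparams.max_seq_len` (8192 marks a Qwen1 checkpoint), and they diverge only in how Q/K/V are produced: Qwen1 runs one fused `attn_qkv` GEMM and slices its 3*n_embd output into per-head views, while Qwen1.5 runs separate query/key/value projections (attn[0..5]) plus a separately indexed output projection (attn[6]). The NumPy sketch below illustrates why both layouts feed the same attention math; the dimensions and random weights are toy values, not taken from either model.

```python
# Toy NumPy sketch: fused-QKV slicing (Qwen1) vs. separate Q/K/V projections
# (Qwen1.5). Sizes and weights are made up purely for illustration.
import numpy as np

N, n_head, head_dim = 4, 8, 16           # tokens, heads, per-head width
n_embd = n_head * head_dim
rng = np.random.default_rng(0)
x = rng.standard_normal((N, n_embd))

# Qwen1 style: one fused projection, then slice the (N, 3*n_embd) result.
w_qkv = rng.standard_normal((n_embd, 3 * n_embd))
q1, k1, v1 = np.split(x @ w_qkv, 3, axis=-1)           # each (N, n_embd)

# Qwen1.5 style: three independent projections producing the same shapes.
w_q, w_k, w_v = (rng.standard_normal((n_embd, n_embd)) for _ in range(3))
q2, k2, v2 = x @ w_q, x @ w_k, x @ w_v                 # each (N, n_embd)

# Both variants are then viewed per head before RoPE and attention, which is
# what the ne_view_3d / ne_reshape_3d calls above arrange.
for t in (q1, k1, v1, q2, k2, v2):
    assert t.reshape(N, n_head, head_dim).shape == (N, n_head, head_dim)
print("fused and split projections yield identically shaped Q/K/V")
```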

neural_speed/models/qwen/qwen_utils.cpp (49 additions, 8 deletions)

```diff
@@ -52,13 +52,16 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
   model.hparams = ml->file_loaders.at(0)->hparams;
   model_file_version file_version = ml->file_loaders.at(0)->file_version;
   auto& hparams = model.hparams;
-  n_ff = hparams.ffn_hidden_size / 2;
+  n_ff = hparams.ffn_hidden_size;
+  if (hparams.max_seq_len == 8192) {
+    n_ff = n_ff / 2;
+  }
   fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
   fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
   fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
   fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
   fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
-  fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size / 2);
+  fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size);
   fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
@@ -102,7 +105,7 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
   model.layers.resize(n_layer);
   size_t vram_total = 0;
 
-  if (ml->verify_tensor("token_embd.weight")) {
+  if (ml->verify_tensor("token_embd.weight")) {  // gguf
     model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
     model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU);
     model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
@@ -117,16 +120,26 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
       layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
 
       // qkv GEMM
-      layer.attn[0] = ml->get_tensor(layers_i + ".attn_qkv.weight", {n_embd, 3 * n_embd}, backend);
-      layer.attn[1] = ml->get_tensor(layers_i + ".attn_qkv.bias", {3 * n_embd}, backend);
-      layer.attn[2] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);
+      if (ml->verify_tensor(layers_i + ".attn_qkv.weight")) {
+        layer.attn[0] = ml->get_tensor(layers_i + ".attn_qkv.weight", {n_embd, 3 * n_embd}, backend);
+        layer.attn[1] = ml->get_tensor(layers_i + ".attn_qkv.bias", {3 * n_embd}, backend);
+        layer.attn[2] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);
+      } else {  // qwen2 gguf
+        layer.attn[0] = ml->get_tensor(layers_i + ".attn_q.weight", {n_embd, n_embd}, backend);
+        layer.attn[1] = ml->get_tensor(layers_i + ".attn_q.bias", {n_embd}, backend);
+        layer.attn[2] = ml->get_tensor(layers_i + ".attn_k.weight", {n_embd, n_embd}, backend);
+        layer.attn[3] = ml->get_tensor(layers_i + ".attn_k.bias", {n_embd}, backend);
+        layer.attn[4] = ml->get_tensor(layers_i + ".attn_v.weight", {n_embd, n_embd}, backend);
+        layer.attn[5] = ml->get_tensor(layers_i + ".attn_v.bias", {n_embd}, backend);
+        layer.attn[6] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);
+      }
 
       // ffn GEMM
       layer.ffn[0] = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, n_ff}, backend);
       layer.ffn[1] = ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, n_ff}, backend);
       layer.ffn[2] = ml->get_tensor(layers_i + ".ffn_down.weight", {n_ff, n_embd}, backend);
     }
-  } else {
+  } else if (ml->verify_tensor("transformer.wte.weight")) {  // qwen1 bin
     model.others[0] = ml->get_tensor("transformer.wte.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
     model.others[1] = ml->get_tensor("transformer.ln_f.weight", {n_embd}, NE_BACKEND_CPU);
     model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
@@ -150,6 +163,34 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
       layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.w2.weight", {n_embd, n_ff}, backend);
       layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.c_proj.weight", {n_ff, n_embd}, backend);
     }
+  } else {  // qwen2 bin
+    model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
+    model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+
+    for (uint32_t i = 0; i < n_layer; ++i) {
+      const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+      auto& layer = model.layers[i];
+      std::string layers_i = "model.layers." + std::to_string(i);
+
+      // norm: cur = ln_1_g*cur + ln_1_b
+      layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
+      layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
+
+      // qkv GEMM + out proj GEMM
+      layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend);
+      layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend);
+      layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd}, backend);
+      layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend);
+      layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd}, backend);
+      layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend);
+      layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
+
+      // ffn GEMM
+      layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, n_ff}, backend);
+      layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight", {n_embd, n_ff}, backend);
+      layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.down_proj.weight", {n_ff, n_embd}, backend);
+    }
   }
 
   // print memory requirements
@@ -180,7 +221,7 @@ class qwen_quant_layer : public quant_layer_base {
  public:
   quant_params_internal get_layer_config(std::string layername, std::vector<int64_t> ne, ne_type type) override {
     bool quantize = layername.rfind("weight") == layername.size() - 6;  // ends with 'weight'?
-    if (layername == "transformer.wte.weight") {
+    if (layername == "transformer.wte.weight" || layername == "model.embed_tokens.weight") {
       // special layer process, can be loaded by config file
       return quant_params_internal();  // return q4_0 to cover the usage of getrow
     }
```
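The loader now probes tensor names to decide which checkpoint flavour it is reading: a GGUF export (with either a fused `attn_qkv` weight or split `attn_q`/`attn_k`/`attn_v` weights per layer), the original Qwen1 `transformer.*` binary, or a Qwen1.5 binary that keeps the Hugging Face `model.*`/`self_attn.*` names; the `n_ff` halving in `init` is likewise kept only for the Qwen1 case. The sketch below mirrors that name-probing dispatch with a plain dict standing in for the model loader; the per-layer `blk.0.` prefix and the format labels are assumptions made for illustration, not names confirmed by this diff.

```python
# Sketch of the checkpoint-format dispatch in QWEN::load; `tensors` is a plain
# dict standing in for the model loader, and the "blk.0." per-layer prefix is
# an assumption made for this example.
def detect_qwen_checkpoint_format(tensors: dict) -> str:
    if "token_embd.weight" in tensors:                    # GGUF export
        if "blk.0.attn_qkv.weight" in tensors:
            return "qwen1-gguf (fused QKV per layer)"
        return "qwen2-gguf (split attn_q/attn_k/attn_v)"
    if "transformer.wte.weight" in tensors:               # original Qwen1 .bin
        return "qwen1-bin (transformer.* names)"
    return "qwen2-bin (model.* / self_attn.* names)"      # Qwen1.5 converted .bin

print(detect_qwen_checkpoint_format({"token_embd.weight": 0,
                                     "blk.0.attn_qkv.weight": 0}))
print(detect_qwen_checkpoint_format({"transformer.wte.weight": 0}))
print(detect_qwen_checkpoint_format({"model.embed_tokens.weight": 0}))
```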

tests/model-test/cpp_graph_inference.sh (6 additions, 0 deletions)

```diff
@@ -155,8 +155,10 @@ model_name_map["qwen-7b"]="Qwen/Qwen-7B-Chat"
 model_name_map["magicoder"]="ise-uiuc/Magicoder-S-DS-6.7B"
 model_name_map["whisper"]="openai/whisper-tiny"
 model_name_map["phi2"]="microsoft/phi-2"
+model_name_map["qwen-1_5"]="Qwen/Qwen1.5-7B-Chat"
 model_name_map["mixtral"]="mistralai/Mixtral-8x7B-Instruct-v0.1"
 
+
 function main() {
     conda_env="$1"
     model="$2"
@@ -251,6 +253,10 @@ function main() {
         quant_script="./build/bin/quant_qwen"
         convert_script="${convert_script}/convert_qwen.py"
        infer_cmd="./build/bin/run_qwen"
+    elif [[ "${model}" == "qwen-1_5" ]]; then
+        quant_script="./build/bin/quant_qwen"
+        convert_script="${convert_script}/convert_qwen.py"
+        infer_cmd="./build/bin/run_qwen"
     elif [[ "${model}" == "magicoder" ]]; then
        quant_script="./build/bin/quant_llama"
        convert_script="${convert_script}/convert_llama.py"
```
