Skip to content

Commit f5f9121

Browse files
jploskicebtenzzreggerganov
authored
llm : add MPT support (#3417)
* CUDA: added support for ggml_clamp (see also: ggml-org/ggml#545) * mpt : added an implementation based (mostly) on falcon integration, modified with deltas from ggml/examples/mpt * mpt : protect against "clip_qkv": null in mpt-7b * mpt : quick fix to avoid "Strange model" warning when quantizing MPT models * mpt : addendum to changeset:84e30e8 - leave parameter clamp_kqv out from metadata rather than use 0.0 to indicate "no clamping" (more compliant with the current GGUF spec?) * mpt : standardized all tensor names to follow GGUF spec * mpt : addendum to changeset:1be89c40 - use "req" parameter of GGUF_GET_KEY macro instead of duplicate code * mpt : fixed comment s/gptneox/mpt/ * mpt : remove tabs, trailing whitespace * mpt : removed ne01 + n_past == ne00 assertion from alibi (cuda/f32) and rope_shift from build_mpt * mpt : updated convert-mpt-hf-to-gguf.py to reflect changes made to convert-gptneox-hf-to-gguf.py in pr:3252 * comment out n_past instead of marking it unused * mpt : removed hardcoded +178 from convert script in favor of utilizing hparams["vocab_size"] * mpt : remove unused tokenizer_json in convert script * ggml : remove obsolete n_past assert in ggml_alibi * llama : print clam_kqv and max_alibi_bias hparams --------- Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 11ea5c7 commit f5f9121

File tree

5 files changed

+685
-9
lines changed

5 files changed

+685
-9
lines changed

convert-mpt-hf-to-gguf.py

+216
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
#!/usr/bin/env python3
2+
# HF mpt--> gguf conversion
3+
4+
from __future__ import annotations
5+
6+
import argparse
7+
import json
8+
import os
9+
import struct
10+
import sys
11+
from pathlib import Path
12+
from typing import Any
13+
14+
import numpy as np
15+
import torch
16+
from transformers import AutoTokenizer # type: ignore[import]
17+
18+
if 'NO_LOCAL_GGUF' not in os.environ:
19+
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
20+
import gguf
21+
22+
23+
def count_model_parts(dir_model: Path) -> int:
24+
num_parts = 0
25+
for filename in os.listdir(dir_model):
26+
if filename.startswith("pytorch_model-"):
27+
num_parts += 1
28+
29+
if num_parts > 0:
30+
print("gguf: found " + str(num_parts) + " model parts")
31+
return num_parts
32+
33+
34+
def parse_args() -> argparse.Namespace:
35+
parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file")
36+
parser.add_argument(
37+
"--vocab-only", action="store_true",
38+
help="extract only the vocab",
39+
)
40+
parser.add_argument(
41+
"--outfile", type=Path,
42+
help="path to write to; default: based on input",
43+
)
44+
parser.add_argument(
45+
"model", type=Path,
46+
help="directory containing model file, or model file itself (*.bin)",
47+
)
48+
parser.add_argument(
49+
"ftype", type=int, choices=[0, 1], default=1, nargs='?',
50+
help="output format - use 0 for float32, 1 for float16",
51+
)
52+
return parser.parse_args()
53+
54+
args = parse_args()
55+
56+
dir_model = args.model
57+
ftype = args.ftype
58+
if not dir_model.is_dir():
59+
print(f'Error: {args.model} is not a directory', file = sys.stderr)
60+
sys.exit(1)
61+
62+
# possible tensor data types
63+
# ftype == 0 -> float32
64+
# ftype == 1 -> float16
65+
66+
# map from ftype to string
67+
ftype_str = ["f32", "f16"]
68+
69+
if args.outfile is not None:
70+
fname_out = args.outfile
71+
else:
72+
# output in the same directory as the model by default
73+
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
74+
75+
print("gguf: loading model "+dir_model.name)
76+
77+
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
78+
hparams = json.load(f)
79+
80+
if hparams["architectures"][0] != "MPTForCausalLM":
81+
print("Model architecture not supported: " + hparams["architectures"][0])
82+
83+
sys.exit()
84+
85+
# get number of model parts
86+
num_parts = count_model_parts(dir_model)
87+
88+
ARCH=gguf.MODEL_ARCH.MPT
89+
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
90+
91+
print("gguf: get model metadata")
92+
93+
block_count = hparams["n_layers"]
94+
95+
gguf_writer.add_name(dir_model.name)
96+
gguf_writer.add_context_length(hparams["max_seq_len"])
97+
gguf_writer.add_embedding_length(hparams["d_model"])
98+
gguf_writer.add_block_count(block_count)
99+
gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
100+
gguf_writer.add_head_count(hparams["n_heads"])
101+
gguf_writer.add_layer_norm_eps(1e-05)
102+
if hparams["attn_config"]["clip_qkv"] is not None:
103+
gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
104+
gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"])
105+
106+
# TOKENIZATION
107+
108+
print("gguf: get tokenizer metadata")
109+
110+
tokens: list[bytearray] = []
111+
scores: list[float] = []
112+
toktypes: list[int] = []
113+
114+
# gpt2 tokenizer
115+
gguf_writer.add_tokenizer_model("gpt2")
116+
117+
print("gguf: get gpt2 tokenizer vocab")
118+
119+
# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but
120+
# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to
121+
# accomodate some "reserved" tokens; this is causing problems down the line in
122+
# llama.cpp, so we pad the vocab with dummy tokens:
123+
124+
vocab_size = hparams["vocab_size"]
125+
126+
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
127+
tokenizer = AutoTokenizer.from_pretrained(dir_model)
128+
129+
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
130+
131+
for i in range(vocab_size):
132+
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
133+
scores.append(0.0) # dummy
134+
toktypes.append(gguf.TokenType.NORMAL)
135+
136+
gguf_writer.add_token_list(tokens)
137+
gguf_writer.add_token_scores(scores)
138+
gguf_writer.add_token_types(toktypes)
139+
140+
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
141+
special_vocab.add_to_gguf(gguf_writer)
142+
143+
# TENSORS
144+
145+
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
146+
147+
# tensor info
148+
print("gguf: get tensor metadata")
149+
150+
if num_parts == 0:
151+
part_names = iter(("pytorch_model.bin",))
152+
else:
153+
part_names = (
154+
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
155+
)
156+
157+
for part_name in part_names:
158+
if args.vocab_only:
159+
break
160+
print("gguf: loading model part '" + part_name + "'")
161+
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
162+
163+
for name in model_part.keys():
164+
data = model_part[name]
165+
166+
old_dtype = data.dtype
167+
168+
# convert any unsupported data types to float32
169+
if data.dtype != torch.float16 and data.dtype != torch.float32:
170+
data = data.to(torch.float32)
171+
172+
data = data.squeeze().numpy()
173+
174+
# map tensor names
175+
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
176+
if new_name is None:
177+
print("Cannot map tensor '" + name + "'")
178+
continue # for the sake of compatibility with some old published models, don't quit
179+
sys.exit()
180+
181+
n_dims = len(data.shape)
182+
data_dtype = data.dtype
183+
184+
# if f32 desired, convert any float16 to float32
185+
if ftype == 0 and data_dtype == np.float16:
186+
data = data.astype(np.float32)
187+
188+
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
189+
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
190+
data = data.astype(np.float32)
191+
192+
# if f16 desired, convert any float32 2-dim weight tensors to float16
193+
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
194+
data = data.astype(np.float16)
195+
196+
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
197+
198+
gguf_writer.add_tensor(new_name, data)
199+
200+
# note: MPT output is tied to (same as) wte in original model;
201+
# for easier implementation in llama.cpp it's duplicated in GGUF, though :/
202+
if new_name == "token_embd.weight":
203+
gguf_writer.add_tensor("output.weight", data)
204+
205+
print("gguf: write header")
206+
gguf_writer.write_header_to_file()
207+
print("gguf: write metadata")
208+
gguf_writer.write_kv_data_to_file()
209+
if not args.vocab_only:
210+
print("gguf: write tensors")
211+
gguf_writer.write_tensors_to_file()
212+
213+
gguf_writer.close()
214+
215+
print(f"gguf: model successfully exported to '{fname_out}'")
216+
print("")

ggml-cuda.cu

+45-2
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
415415
#define CUDA_SILU_BLOCK_SIZE 256
416416
#define CUDA_CPY_BLOCK_SIZE 32
417417
#define CUDA_SCALE_BLOCK_SIZE 256
418+
#define CUDA_CLAMP_BLOCK_SIZE 256
418419
#define CUDA_ROPE_BLOCK_SIZE 256
419420
#define CUDA_ALIBI_BLOCK_SIZE 32
420421
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
@@ -4585,6 +4586,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
45854586
dst[i] = scale * x[i];
45864587
}
45874588

4589+
static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
4590+
const int i = blockDim.x*blockIdx.x + threadIdx.x;
4591+
4592+
if (i >= k) {
4593+
return;
4594+
}
4595+
4596+
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
4597+
}
45884598

45894599
template<int qk, int qr, dequantize_kernel_t dq>
45904600
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
@@ -5475,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
54755485
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
54765486
}
54775487

5488+
static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
5489+
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
5490+
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
5491+
}
5492+
54785493
template<typename T>
54795494
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
54805495
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
@@ -6419,12 +6434,12 @@ inline void ggml_cuda_op_alibi(
64196434
const int64_t ne02 = src0->ne[2];
64206435
const int64_t nrows = ggml_nrows(src0);
64216436

6422-
const int n_past = ((int32_t *) dst->op_params)[0];
6437+
//const int n_past = ((int32_t *) dst->op_params)[0];
64236438
const int n_head = ((int32_t *) dst->op_params)[1];
64246439
float max_bias;
64256440
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
64266441

6427-
GGML_ASSERT(ne01 + n_past == ne00);
6442+
//GGML_ASSERT(ne01 + n_past == ne00);
64286443
GGML_ASSERT(n_head == ne02);
64296444

64306445
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -6500,6 +6515,24 @@ inline void ggml_cuda_op_scale(
65006515
(void) src1_dd;
65016516
}
65026517

6518+
inline void ggml_cuda_op_clamp(
6519+
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
6520+
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
6521+
6522+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
6523+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
6524+
6525+
const float min = ((float *) dst->op_params)[0];
6526+
const float max = ((float *) dst->op_params)[1];
6527+
6528+
clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
6529+
CUDA_CHECK(cudaGetLastError());
6530+
6531+
(void) src1;
6532+
(void) dst;
6533+
(void) src1_dd;
6534+
}
6535+
65036536
static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
65046537
const int64_t nrows0 = ggml_nrows(src0);
65056538

@@ -7061,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
70617094
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
70627095
}
70637096

7097+
static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
7098+
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
7099+
}
7100+
70647101
static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
70657102
const int64_t ne = ggml_nelements(src0);
70667103
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -7470,6 +7507,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
74707507
case GGML_OP_SCALE:
74717508
func = ggml_cuda_scale;
74727509
break;
7510+
case GGML_OP_CLAMP:
7511+
if (!any_on_device) {
7512+
return false;
7513+
}
7514+
func = ggml_cuda_clamp;
7515+
break;
74737516
case GGML_OP_CPY:
74747517
func = ggml_cuda_cpy;
74757518
break;

ggml-metal.m

+1-1
Original file line numberDiff line numberDiff line change
@@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute(
12991299

13001300
const int nth = MIN(1024, ne00);
13011301

1302-
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
1302+
//const int n_past = ((int32_t *) dst->op_params)[0];
13031303
const int n_head = ((int32_t *) dst->op_params)[1];
13041304
float max_bias;
13051305
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

ggml.c

+1-3
Original file line numberDiff line numberDiff line change
@@ -13059,13 +13059,11 @@ static void ggml_compute_forward_alibi_f32(
1305913059
return;
1306013060
}
1306113061

13062-
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
13062+
//const int n_past = ((int32_t *) dst->op_params)[0];
1306313063
const int n_head = ((int32_t *) dst->op_params)[1];
1306413064
float max_bias;
1306513065
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1306613066

13067-
assert(n_past >= 0);
13068-
1306913067
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
1307013068
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
1307113069
const int64_t ne2 = src0->ne[2]; // n_head -> this is k

0 commit comments

Comments
 (0)