
Commit 4e54fd9

Make loading weights 10-100x faster
This is a breaking change that's going to give you three benefits:

1. Your inference commands should load 100x faster
2. You may be able to safely load models 2x larger
3. You can run many concurrent inference processes

This was accomplished by changing the file format so we can mmap() weights directly into memory without having to read() or copy them, thereby ensuring the kernel can make its file cache pages directly accessible to our inference processes; and secondly, that the file cache pages are much less likely to get evicted (which would force loads to hit disk) because they're no longer competing with memory pages that were needlessly created by gigabytes of standard i/o.

Furthermore, this change ensures that tensors are aligned properly on a 32-byte boundary. That opens the door to seeing if we can get additional performance gains on some microprocessors, by using ops that require memory alignment.

Lastly, note that both POSIX and the Windows platform are supported.

Fixes ggml-org#91
1 parent 80c2178 commit 4e54fd9
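
Editor's note: the sketch below is not part of this commit; it only illustrates why the aligned layout is mmap()-friendly. It assumes you have already parsed the header and tensor records of a converted file and therefore know a tensor's dtype, shape, and its 32-byte-aligned data offset (the `path`, `offset`, `shape`, and `dtype` arguments are hypothetical inputs, not values defined by this change).

# Illustrative only -- not code from this commit.
import mmap
import numpy as np

def map_tensor(path, offset, shape, dtype=np.float16):
    """Zero-copy view of tensor data stored at `offset` inside the file."""
    with open(path, "rb") as f:
        # Map the whole file read-only; the kernel's page cache backs it.
        mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    assert offset % 32 == 0, "tensor data should start on a 32-byte boundary"
    count = int(np.prod(shape))
    # No read()/copy: the returned array is a view into the mapping.
    return np.frombuffer(mm, dtype=dtype, count=count, offset=offset).reshape(shape)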

4 files changed (+310 −373 lines)

Diff for: .gitignore (+1)

@@ -22,6 +22,7 @@ models/*
 /result
 /perplexity
 /embedding
+/Pipfile
 
 arm_neon.h
 compile_commands.json

Diff for: convert-pth-to-ggml.py (+149 −52)

@@ -24,16 +24,64 @@
 
 from sentencepiece import SentencePieceProcessor
 
-def parse_args():
+QK = 32
+
+GGML_TYPE_Q4_0 = 0
+GGML_TYPE_Q4_1 = 1
+GGML_TYPE_I8 = 2
+GGML_TYPE_I16 = 3
+GGML_TYPE_I32 = 4
+GGML_TYPE_F16 = 5
+GGML_TYPE_F32 = 6
+
+WTYPES = {
+    0: GGML_TYPE_F32,
+    1: GGML_TYPE_F16,
+    2: GGML_TYPE_Q4_0,
+    3: GGML_TYPE_Q4_1,
+}
+
+GGML_BLCK_SIZE = {
+    GGML_TYPE_Q4_0: QK,
+    GGML_TYPE_Q4_1: QK,
+    GGML_TYPE_I8: 1,
+    GGML_TYPE_I16: 1,
+    GGML_TYPE_I32: 1,
+    GGML_TYPE_F16: 1,
+    GGML_TYPE_F32: 1,
+}
+
+GGML_TYPE_SIZE = {
+    GGML_TYPE_Q4_0: 4 + QK/2,
+    GGML_TYPE_Q4_1: 4*2 + QK/2,
+    GGML_TYPE_I8: 1,
+    GGML_TYPE_I16: 2,
+    GGML_TYPE_I32: 4,
+    GGML_TYPE_F16: 2,
+    GGML_TYPE_F32: 4,
+}
+
+def ggml_nelements(shape):
+    r = 1
+    for i in shape:
+        r *= i
+    return r
+
+def ggml_nbytes(shape, ftype):
+    x = ggml_nelements(shape)
+    t = WTYPES[ftype]
+    x *= GGML_TYPE_SIZE[t]
+    x //= GGML_BLCK_SIZE[t]
+    return x
 
+def parse_args():
     parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
     parser.add_argument('dir_model', help='directory containing the model checkpoint')
     parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
     parser.add_argument('vocab_only', help='only write vocab to file', type=int, default=0, nargs='?')
     return parser.parse_args()
 
 def get_n_parts(dim):
-
     mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
     n_parts = mappings.get(dim)
     if n_parts is None:
@@ -44,30 +92,24 @@ def get_n_parts(dim):
     return n_parts
 
 def load_hparams_and_tokenizer(dir_model):
-
     # `dir_model` is something like `models/7B` or `models/7B/`.
     # "tokenizer.model" is expected under model's parent dir.
     # When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found.
     # Let's use the model's parent dir directly.
     model_parent_dir = os.path.dirname(os.path.normpath(dir_model))
-
     fname_hparams = f"{dir_model}/params.json"
     fname_tokenizer = f"{model_parent_dir}/tokenizer.model"
-
     with open(fname_hparams, "r") as f:
         hparams = json.load(f)
         print(hparams)
-
     tokenizer = SentencePieceProcessor(fname_tokenizer)
     hparams.update({"vocab_size": tokenizer.vocab_size()})
-
     return hparams, tokenizer
 
 def write_header(fout, hparams, ftype):
-
     keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
     values = [
-        0x67676d66, # magic: ggmf in hex
+        0x67676a74, # magic: ggjt in hex
         1, # file version
         *[hparams[key] for key in keys],
         hparams["dim"] // hparams["n_heads"], # rot (obsolete)
@@ -76,7 +118,6 @@ def write_header(fout, hparams, ftype):
     fout.write(struct.pack("i" * len(values), *values))
 
 def write_tokens(fout, tokenizer):
-
     for i in range(tokenizer.vocab_size()):
         if tokenizer.is_unknown(i):
             text = " \u2047 ".encode("utf-8")
@@ -95,85 +136,141 @@ def write_tokens(fout, tokenizer):
         fout.write(text)
         fout.write(struct.pack("f", tokenizer.get_score(i)))
 
-def process_and_write_variables(fout, model, ftype):
-
+def process_and_write_variables(fout, model, ftype, part_id, n_parts):
     for name, datao in model.items():
-
         if name.endswith("freqs"):
             continue
 
-        shape = datao.shape
-
-        print(f"Processing variable: {name} with shape: {shape} and type: {datao.dtype}")
-
+        # remove dimensions with a single element
         data = datao.numpy().squeeze()
-        n_dims = len(shape)
+        partshape = data.shape
+        n_dims = len(data.shape)
+        assert n_dims in (1, 2)
+
+        print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}")
 
-        # default type is fp16
+        # coerce single-dimensional tensors from float16 to float32
         ftype_cur = 1
         if ftype == 0 or n_dims == 1:
             print("  Converting to float32")
             data = data.astype(np.float32)
             ftype_cur = 0
-
-        # header
+        blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
+        type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]
+
+        # determine dimension along which multipart tensor is sharded
+        #
+        # split_dim 0 regex:
+        #   - output.*
+        #   - layers.*.attention.wq.weight
+        #   - layers.*.attention.wk.weight
+        #   - layers.*.attention.wv.weight
+        #   - layers.*.feed_forward.w1.weight
+        #   - layers.*.feed_forward.w3.weight
+        #
+        # split_dim 1 regex:
+        #   - tok_embeddings.*
+        #   - layers.*.attention.wo.weight
+        #   - layers.*.feed_forward.w2.weight
+        #
+        if n_dims > 1:
+            split_dim = 1
+            if "tok_embeddings" in name:
+                split_dim = 1
+            elif "layers" in name:
+                if "attention.wo.weight" in name:
+                    split_dim = 1
+                elif "feed_forward.w2.weight" in name:
+                    split_dim = 1
+                else:
+                    split_dim = 0
+            elif "output" in name:
+                split_dim = 0
+
+        # output tensor header
+        fullshape = list(partshape)
+        if n_dims > 1:
+            fullshape[split_dim] *= n_parts
         sname = name.encode('utf-8')
-        fout.write(struct.pack("iii", len(data.shape), len(sname), ftype_cur))
-        for dim in reversed(data.shape):
+        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
+        for dim in reversed(fullshape):
             fout.write(struct.pack("i", dim))
         fout.write(sname)
 
-        # data output to file
-        data.tofile(fout)
+        # ensure tensor data is aligned
+        tensor_data_offset = fout.tell()
+        while tensor_data_offset % QK != 0:
+            fout.write(struct.pack("B", 0))
+            tensor_data_offset += 1
+
+        # output unified mappable tensor data
+        if n_dims == 1 or n_parts == 1:
+            # copy tensor which we thankfully received in one piece
+            if part_id == 0:
+                data.tofile(fout)
+        elif split_dim == 0:
+            # reassemble multifile tensor containing some of the rows
+            rows_per_chunk = partshape[0]
+            current_row = part_id * rows_per_chunk
+            bytes_per_row = fullshape[1] // blck_size * type_size
+            offset = current_row * bytes_per_row
+            fout.seek(tensor_data_offset + offset)
+            data.tofile(fout)
+        elif split_dim == 1:
+            # reassemble multifile tensor containing some of the cols
+            cols_per_chunk = partshape[1]
+            current_col = part_id * cols_per_chunk
+            bytes_per_row = fullshape[1] // blck_size * type_size
+            offset_current_col = current_col // blck_size * type_size
+            for row in range(partshape[0]):
+                offset_row = row * bytes_per_row
+                offset = offset_row + offset_current_col
+                fout.seek(tensor_data_offset + offset)
+                data[row].tofile(fout)
+
+        # advance file position to next tensor
+        fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))
 
 def main():
-
     args = parse_args()
     dir_model = args.dir_model
     ftype = args.ftype
     ftype_str = ["f32", "f16"]
-
     hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
 
     print(args)
 
     # if only writing vocab to file
     if args.vocab_only:
-
         fname_model = f"{dir_model}/consolidated.00.pth"
         fname_out = f"{dir_model}/ggml-vocab.bin"
-
         print(f"Extracting only the vocab from '{fname_model}'\n")
-
-
+        model = torch.load(fname_model, map_location="cpu")
         with open(fname_out, "wb") as fout:
             write_header(fout, hparams, ftype)
             write_tokens(fout, tokenizer)
-
-
+        del model
        print(f"Done. Output file: {fname_out}\n")
-
         return
 
     n_parts = get_n_parts(hparams["dim"])
-
-    for p in range(n_parts):
-
-        print(f"Processing part {p+1} of {n_parts}\n")
-
-        fname_model = f"{dir_model}/consolidated.0{p}.pth"
-        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{'' if p == 0 else '.' + str(p)}"
-
-        model = torch.load(fname_model, map_location="cpu")
-
-        with open(fname_out, "wb") as fout:
-            write_header(fout, hparams, ftype)
-            write_tokens(fout, tokenizer)
-            process_and_write_variables(fout, model, ftype)
-
-        del model
-
-        print(f"Done. Output file: {fname_out}, (part {p})\n")
+    fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
+
+    # we output a single file for ggml
+    with open(fname_out, "wb") as fout:
+        write_header(fout, hparams, ftype)
+        write_tokens(fout, tokenizer)
+        offset_of_tensors = fout.tell()
+        # the tensors we load could be split across multiple files
+        for part_id in range(n_parts):
+            fout.seek(offset_of_tensors)
+            print(f"Processing part {part_id+1} of {n_parts}\n")
+            fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
+            model = torch.load(fname_model, map_location="cpu")
+            process_and_write_variables(fout, model, ftype, part_id, n_parts)
+            del model
+
+    print(f"Done. Output file: {fname_out}\n")
 
 if __name__ == "__main__":
     main()
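
Editor's note: to make the new reassembly arithmetic in process_and_write_variables() easier to follow, here is a small self-contained worked example of the column-split (split_dim == 1) offset computation, using the same formulas as the diff. The 8x8 float16 tensor split into two 8x4 checkpoint parts is made up purely for illustration.

# Worked example (illustrative values, not part of the commit) of where a row
# from a column-sharded part lands inside the reassembled full tensor.
TYPE_SIZE_F16 = 2   # bytes per f16 element (GGML_TYPE_SIZE)
BLCK_SIZE_F16 = 1   # f16 is not block-quantized (GGML_BLCK_SIZE)

def col_shard_offset(row, part_id, partshape, fullshape):
    """Byte offset, relative to the tensor's aligned start, where `row` of
    checkpoint part `part_id` belongs in the full tensor."""
    bytes_per_row = fullshape[1] // BLCK_SIZE_F16 * TYPE_SIZE_F16
    current_col = part_id * partshape[1]
    offset_current_col = current_col // BLCK_SIZE_F16 * TYPE_SIZE_F16
    return row * bytes_per_row + offset_current_col

# Hypothetical 8x8 f16 tensor stored as two 8x4 parts:
# row 0 of part 1 starts 4 cols * 2 bytes = 8 bytes into full row 0;
# row 3 of part 1 starts 3 full rows (48 bytes) + 8 bytes = 56 bytes in.
assert col_shard_offset(0, 1, (8, 4), (8, 8)) == 8
assert col_shard_offset(3, 1, (8, 4), (8, 8)) == 56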
