-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: convert-hf-to-powerinfer-gguf.py
633 lines (514 loc) · 22.3 KB
/
convert-hf-to-powerinfer-gguf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import contextlib
import json
import os
import re
import struct
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Optional, cast

import numpy as np
import torch
import torch.nn as tnn

if TYPE_CHECKING:
    from torch import Tensor

if "NO_LOCAL_GGUF" not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))
import gguf
###### MODEL DEFINITIONS ######
class SentencePieceTokenTypes(IntEnum):
    """Token categories assigned to SentencePiece pieces during conversion.

    These integer values are what ``Model._set_vocab_sentencepiece`` passes to
    ``add_token_types``; renumbering them would change the emitted file.
    """

    NORMAL = 1        # any piece not matched by the checks below
    UNKNOWN = 2       # tokenizer.is_unknown(id)
    CONTROL = 3       # tokenizer.is_control(id)
    USER_DEFINED = 4  # entries appended from added_tokens.json
    UNUSED = 5        # tokenizer.is_unused(id)
    BYTE = 6          # tokenizer.is_byte(id)
class ReluMLP(tnn.Module):
    """Bias-free two-layer MLP (``fc1`` -> ReLU -> ``fc2``) used as a predictor.

    The attribute names ``fc1`` and ``fc2`` matter: their state-dict keys
    (``fc1.weight`` / ``fc2.weight``) are re-emitted by ``Model.get_tensors`` as
    ``blk.<layer>.fc1.weight`` / ``blk.<layer>.fc2.weight``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = tnn.Linear(input_dim, hidden_dim, bias=False)
        self.relu = tnn.ReLU()
        self.fc2 = tnn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x):
        """Apply ``fc1``, ReLU, then ``fc2`` to ``x``."""
        return self.fc2(self.relu(self.fc1(x)))

    @classmethod
    def from_file(cls, model_file: Path) -> "ReluMLP":
        """Build a predictor from a saved state dict.

        Layer sizes are inferred from the stored weights; ``nn.Linear`` keeps
        its weight as ``(out_features, in_features)``.

        Args:
            model_file: Path to a file produced by ``torch.save(state_dict)``.

        Returns:
            A ``ReluMLP`` with the loaded weights.

        Raises:
            KeyError: if ``fc1.weight`` or ``fc2.weight`` is absent.
        """
        # NOTE(review): torch.load unpickles the file; only load predictor
        # checkpoints from trusted sources.
        state = torch.load(model_file, map_location="cpu")
        try:
            hidden_size, input_size = state["fc1.weight"].shape
            output_size, _ = state["fc2.weight"].shape
        except KeyError as err:
            # Indexing (instead of .get) turns a missing key into a clear error
            # rather than "'NoneType' object has no attribute 'shape'".
            raise KeyError(f"{model_file}: missing predictor weight {err}") from err
        mlp = cls(input_size, hidden_size, output_size)
        mlp.load_state_dict(state)
        return mlp
class Model(ABC):
    """Base class for model conversion.

    Converts a Hugging Face checkpoint directory plus a directory of PowerInfer
    MLP predictor checkpoints into one GGUF file. Subclasses supply the
    architecture metadata (``set_gguf_parameters``) and the tensor
    transformations (``write_tensors``).
    """

    def __init__(
        self,
        dir_model: Path,
        dir_mlp_pred: Path,
        ftype: int,
        fname_out: Path,
        is_big_endian: bool,
    ):
        """Read ``config.json``, locate the weight shards and open the GGUF writer.

        Args:
            dir_model: Directory with ``config.json`` and the model weights.
            dir_mlp_pred: Directory with the ``model_<layer>.pt`` predictors.
            ftype: Requested output type; subclasses' ``write_tensors`` treat
                0 as f32 and 1 as f16.
            fname_out: Path of the GGUF file to create.
            is_big_endian: Write the file in big-endian byte order.
        """
        self.dir_model = dir_model
        self.dir_mlp_pred = dir_mlp_pred
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = (
            gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        )
        self.is_safetensors = self._is_model_safetensors()
        # Only count files that follow the shard naming used by
        # _get_part_names(); stray files such as ``training_args.bin`` would
        # otherwise inflate the count and yield shard names that don't exist.
        if self.is_safetensors:
            self.num_parts = Model.count_model_parts(
                self.dir_model, ".safetensors", name_prefix="model"
            )
        else:
            self.num_parts = Model.count_model_parts(
                self.dir_model, ".bin", name_prefix="pytorch_model"
            )
        self.part_names = self._get_part_names()
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(
            fname_out,
            gguf.MODEL_ARCH_NAMES[self.model_arch],
            endianess=self.endianess,
            use_temp_file=False,
        )

    def set_vocab(self):
        """Write the tokenizer vocabulary; defaults to the GPT-2/BPE path."""
        self._set_vocab_gpt2()

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        """Yield ``(name, tensor)`` pairs to be written to the GGUF file.

        MLP predictor weights come first, renamed to ``blk.<layer>.<param>``
        (e.g. ``blk.0.fc1.weight``); the model's own weights follow under their
        original checkpoint names.

        Raises:
            FileNotFoundError: if no model weight shards were found in
                ``dir_model``.
        """
        if self.num_parts == 0:
            raise FileNotFoundError(f"No model weight files found in {self.dir_model}")

        for model_layer, part_name in self._get_mlp_part_layer_names():
            print(f"gguf: loading mlp part '{part_name}'")
            mlp_model = ReluMLP.from_file(self.dir_mlp_pred / part_name)
            for name, data in mlp_model.state_dict().items():
                yield f"blk.{model_layer}.{name}", data

        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                from safetensors import safe_open

                ctx = cast(
                    ContextManager[Any],
                    safe_open(self.dir_model / part_name, framework="pt", device="cpu"),
                )
            else:
                # NOTE(review): torch.load unpickles the shard; only convert
                # checkpoints from trusted sources.
                ctx = contextlib.nullcontext(
                    torch.load(self.dir_model / part_name, map_location="cpu")
                )

            with ctx as model_part:
                for name in model_part.keys():
                    data = (
                        model_part.get_tensor(name)
                        if self.is_safetensors
                        else model_part[name]
                    )
                    yield name, data

    @abstractmethod
    def set_gguf_parameters(self, params: PredictorParams):
        """Write architecture hyperparameters and predictor settings
        (``params``) to ``self.gguf_writer``."""

    @abstractmethod
    def write_tensors(self):
        """Convert every tensor from ``get_tensors`` and add it to the writer."""

    def write(self):
        """Add all tensors via ``write_tensors``, then write the header,
        key/value metadata and tensor data, and close the file."""
        self.write_tensors()
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file()
        self.gguf_writer.close()

    def write_vocab(self):
        """Write only the header and key/value metadata (no tensors), then close."""
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.close()

    @staticmethod
    def count_model_parts(dir_model: Path, prefix: str, name_prefix: str = "") -> int:
        """Count files directly inside ``dir_model`` matching a name pattern.

        Despite its name, ``prefix`` is matched against the *end* of each
        filename (e.g. ``".bin"``); it keeps that name for backward
        compatibility.

        Args:
            dir_model: Directory to scan (non-recursive).
            prefix: Required filename ending, e.g. ``".safetensors"``.
            name_prefix: Required filename start, e.g. ``"model"``. The default
                ``""`` accepts every name, matching the original behavior.

        Returns:
            The number of matching directory entries.
        """
        return sum(
            1
            for filename in os.listdir(dir_model)
            if filename.startswith(name_prefix) and filename.endswith(prefix)
        )

    @staticmethod
    def load_hparams(dir_model):
        """Return the parsed ``config.json`` from ``dir_model``."""
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def from_model_architecture(model_architecture):
        """Return the converter class for a ``config.json`` ``architectures`` entry.

        Raises:
            NotImplementedError: for architectures without a converter.
        """
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "LlamaForCausalLM":
            return LlamaModel
        raise NotImplementedError(f'Architecture "{model_architecture}" not supported!')

    def _is_model_safetensors(self) -> bool:
        """True when ``dir_model`` contains ``model*.safetensors`` weight files."""
        return (
            Model.count_model_parts(self.dir_model, ".safetensors", name_prefix="model")
            > 0
        )

    def _get_mlp_part_layer_names(self):
        """Returns a generator of (index, name) for MLP predictors of each model layer.

        Assumes predictor files are named ``model_0.pt`` ... ``model_{N-1}.pt``
        where N is the number of ``.pt`` files in ``dir_mlp_pred``.
        """
        n_mlp_parts = Model.count_model_parts(self.dir_mlp_pred, ".pt")
        return ((n, f"model_{n}.pt") for n in range(n_mlp_parts))

    def _get_part_names(self) -> tuple[str, ...]:
        """Return the expected weight-shard filenames in load order.

        Uses the Hugging Face naming scheme: ``model.safetensors`` or
        ``model-00001-of-0000N.safetensors`` (and the ``pytorch_model`` ``.bin``
        equivalents). A tuple is returned so the names can be iterated more
        than once.
        """
        if self.is_safetensors:
            stem, ext = "model", "safetensors"
        else:
            stem, ext = "pytorch_model", "bin"
        if self.num_parts == 1:  # a single, unsharded file
            return (f"{stem}.{ext}",)
        return tuple(
            f"{stem}-{n:05}-of-{self.num_parts:05}.{ext}"
            for n in range(1, self.num_parts + 1)
        )

    def _get_model_architecture(self) -> gguf.MODEL_ARCH:
        """Map ``config.json``'s first architecture name to a GGUF architecture.

        ``RWForCausalLM`` is the legacy Falcon name; it must map to FALCON to
        agree with ``from_model_architecture`` (which returns ``FalconModel``)
        so Falcon tensor names resolve against the Falcon tensor map.
        """
        arch = self.hparams["architectures"][0]
        if arch in ("FalconForCausalLM", "RWForCausalLM"):
            return gguf.MODEL_ARCH.FALCON
        if arch == "LlamaForCausalLM":
            return gguf.MODEL_ARCH.LLAMA
        raise NotImplementedError(f'Architecture "{arch}" not supported!')

    def _translate_tensor_key(
        self, key: str, try_suffixes=(".weight", ".bias")
    ) -> Optional[str]:
        """Map a checkpoint tensor name to its GGUF name.

        Returns ``None`` when the name is neither in the architecture's tensor
        map nor a predictor tensor (``blk.<n>.fc<k>.weight``).
        """
        # The tensor map depends only on hparams/model_arch, so build it once
        # instead of once per tensor.
        tensor_map = getattr(self, "_tensor_map", None)
        if tensor_map is None:
            block_count = self.hparams.get(
                "n_layers",
                self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
            )
            tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
            self._tensor_map = tensor_map

        arch_tensor_key = tensor_map.get_name(key, try_suffixes=try_suffixes)
        if arch_tensor_key is not None:
            return arch_tensor_key

        # Predictor tensors are already named GGUF-style by get_tensors().
        if re.fullmatch(r"blk\.\d+\.fc\d\.weight", key):
            return key
        return None

    def _set_vocab_gpt2(self):
        """Write a GPT-2 style (BPE) vocabulary from the Hugging Face tokenizer.

        Ids below ``vocab_size`` that the tokenizer does not define are filled
        with ``[PAD<id>]`` placeholders; added tokens are typed CONTROL when
        marked special and USER_DEFINED otherwise.

        Raises:
            ValueError: if a tokenizer id is not below ``vocab_size``.
        """
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[str | bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer  # type: ignore[attr-defined]

        tokenizer = AutoTokenizer.from_pretrained(dir_model)
        vocab = tokenizer.vocab  # fetched once; reused below
        vocab_size = hparams.get("vocab_size", len(vocab))
        max_id = max(vocab.values())
        if max_id >= vocab_size:
            raise ValueError(
                f"tokenizer id {max_id} does not fit in vocab_size {vocab_size}"
            )

        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
        added_vocab = tokenizer.get_added_vocab()

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode("utf-8")
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                if tokenizer.added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_sentencepiece(self):
        """Write a SentencePiece (``llama``) vocabulary from ``tokenizer.model``.

        Ids ``0 .. vocab_size-1`` are read from the SentencePiece model, then
        any keys of ``added_tokens.json`` are appended with score -1000.0 and
        type USER_DEFINED. Exits with status 1 if ``tokenizer.model`` is missing.
        """
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / "tokenizer.model"

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f"Error: Missing {tokenizer_path}", file=sys.stderr)
            sys.exit(1)

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / "added_tokens.json"
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)
            for key in added_tokens_json:
                tokens.append(key.encode("utf-8"))
                scores.append(-1000.0)
                toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)
class LlamaModel(Model):
    """Converter for ``LlamaForCausalLM`` checkpoints and their MLP predictors."""

    # Checkpoint entries that are not needed in the GGUF file.
    _SKIPPED_SUFFIXES = (
        ".attention.masked_bias",
        ".attention.bias",
        ".attention.rotary_emb.inv_freq",
    )

    def set_vocab(self):
        """Use the SentencePiece ``tokenizer.model`` from the model directory."""
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self, params: PredictorParams):
        """Write LLaMA hyperparameters from ``config.json`` plus the optional
        predictor sparse threshold."""
        hparams = self.hparams
        n_head = hparams["num_attention_heads"]
        writer = self.gguf_writer

        writer.add_name("Llama")
        # Prefer the trained context length from config.json; 2048 (the value
        # previously hard-coded here) remains the fallback when it is absent.
        writer.add_context_length(hparams.get("max_position_embeddings", 2048))
        writer.add_embedding_length(hparams["hidden_size"])
        writer.add_block_count(hparams["num_hidden_layers"])
        writer.add_feed_forward_length(hparams["intermediate_size"])
        writer.add_rope_dimension_count(hparams["hidden_size"] // n_head)
        writer.add_head_count(n_head)
        # Configs without num_key_value_heads use one K/V head per attention
        # head (Hugging Face LlamaConfig applies the same default).
        writer.add_head_count_kv(hparams.get("num_key_value_heads", n_head))
        writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
        # Older configs omit rope_theta; 10000.0 is LlamaConfig's default.
        writer.add_rope_freq_base(hparams.get("rope_theta", 10000.0))
        writer.add_file_type(self.ftype)
        if params.sparse_threshold is not None:
            writer.add_sparse_threshold(params.sparse_threshold)

    def write_tensors(self):
        """Convert every checkpoint and predictor tensor and add it to the writer.

        Names are mapped to GGUF names, FFN-down weights are transposed, and
        dtypes are adjusted to the requested ``ftype`` (0 = f32, 1 = f16).
        Exits the process with status 1 if a tensor name cannot be mapped.
        """
        for name, data_torch in self.get_tensors():
            if name.endswith(self._SKIPPED_SUFFIXES):
                continue

            old_dtype = data_torch.dtype
            # Widen anything numpy can't hold directly (e.g. bfloat16) to f32.
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)
            data = data_torch.squeeze().numpy()

            new_name = self._translate_tensor_key(name)
            if new_name is None:
                # Non-zero status so callers/scripts see the failure.
                print(f"Can not map tensor {name!r}", file=sys.stderr)
                sys.exit(1)

            # We need to transpose the weight matrices for the FFN Down layers to
            # support the Axpy operation in PowerInfer, so they don't need to be
            # transposed at runtime.
            if "ffn_down" in new_name:
                new_name = new_name.replace("ffn_down", "ffn_down_t")
                data = data.T

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # f32 output requested: widen f16 tensors.
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no
            # reason to store float16 as float32.
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # f16 output requested: narrow 2-D f32 weight matrices.
            if (
                self.ftype == 1
                and data_dtype == np.float32
                and name.endswith(".weight")
                and n_dims == 2
            ):
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)
class FalconModel(Model):
    """Converter for Falcon checkpoints (``FalconForCausalLM`` / legacy ``RWForCausalLM``)."""

    def _head_counts(self) -> tuple[int, int]:
        """Return ``(n_head, n_head_kv)`` from current or legacy config keys.

        ``n_head_kv`` falls back to 1 when neither ``num_kv_heads`` nor
        ``n_head_kv`` is present.
        """
        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name
        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name
        return n_head, n_head_kv

    def set_gguf_parameters(self, params: PredictorParams):
        """Write Falcon hyperparameters plus the optional predictor sparse threshold."""
        hparams = self.hparams
        block_count = hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = hparams["n_layer"]  # old name
        n_head, n_head_kv = self._head_counts()
        writer = self.gguf_writer

        writer.add_name("Falcon")
        # Legacy RW configs have no max_position_embeddings; keep 2048 for them.
        writer.add_context_length(hparams.get("max_position_embeddings", 2048))
        writer.add_tensor_data_layout("jploski")  # qkv tensor transform
        writer.add_embedding_length(hparams["hidden_size"])
        writer.add_feed_forward_length(4 * hparams["hidden_size"])
        writer.add_block_count(block_count)
        writer.add_head_count(n_head)
        writer.add_head_count_kv(n_head_kv)
        writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
        writer.add_file_type(self.ftype)
        if params.sparse_threshold is not None:
            writer.add_sparse_threshold(params.sparse_threshold)

    def write_tensors(self):
        """Convert every checkpoint and predictor tensor and add it to the writer.

        Rearranges the fused QKV weights, maps names to GGUF names, transposes
        FFN-down weights and adjusts dtypes to ``ftype`` (0 = f32, 1 = f16).
        Exits the process with status 1 if a tensor name cannot be mapped.
        """
        n_head, n_head_kv = self._head_counts()
        head_dim = self.hparams["hidden_size"] // n_head

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # Widen anything numpy can't hold directly (e.g. bfloat16) to f32.
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # QKV tensor transform
            # The original query_key_value tensor contains n_head_kv "kv groups",
            # each consisting of n_head/n_head_kv query weights followed by one key
            # and one value weight (shared by all query heads in the kv group).
            # This layout makes it a big pain to work with in GGML.
            # So we rearrange them here, so that we have n_head query weights
            # followed by n_head_kv key weights followed by n_head_kv value weights,
            # in contiguous fashion.
            # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
            if "query_key_value" in name:
                qkv = data_torch.view(
                    n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head
                )
                q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
                k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
                v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
                data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

            data = data_torch.squeeze().numpy()

            new_name = self._translate_tensor_key(name)
            if new_name is None:
                # Non-zero status so callers/scripts see the failure.
                print(f"Can not map tensor {name!r}", file=sys.stderr)
                sys.exit(1)

            # We need to transpose the weight matrices for the FFN Down layers to
            # support the Axpy operation in PowerInfer, so they don't need to be
            # transposed at runtime.
            if "ffn_down" in new_name:
                new_name = new_name.replace("ffn_down", "ffn_down_t")
                data = data.T

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # f32 output requested: widen f16 tensors.
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no
            # reason to store float16 as float32.
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # f16 output requested: narrow 2-D f32 weight matrices.
            if (
                self.ftype == 1
                and data_dtype == np.float32
                and name.endswith(".weight")
                and n_dims == 2
            ):
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)
@dataclass
class PredictorParams:
    """Optional settings for the MLP predictors.

    Loaded from ``config.json`` in the predictor directory when that file
    exists; otherwise every field keeps its default.
    """

    # Passed to add_sparse_threshold by the converters; None means "not set".
    sparse_threshold: float | None = None

    @staticmethod
    def loadPredictorJson(config_path: Path) -> PredictorParams:
        """Parse a predictor ``config.json``; absent keys become ``None``."""
        # Context manager closes the handle (the previous open() leaked it).
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        return PredictorParams(
            sparse_threshold=config.get("sparse_threshold"),
        )

    @staticmethod
    def load(model_instance: Model) -> PredictorParams:
        """Read ``<model_instance.dir_mlp_pred>/config.json`` if present,
        else return default params."""
        config_path = model_instance.dir_mlp_pred / "config.json"
        if not config_path.exists():
            return PredictorParams()
        return PredictorParams.loadPredictorJson(config_path)
###### CONVERSION LOGIC ######
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Positionals: the Hugging Face model directory and the directory of MLP
    predictor checkpoints. Options: vocab-only export, output path, output
    precision (f32/f16, default f16) and big-endian output.
    """
    cli = argparse.ArgumentParser(
        description="Convert a huggingface model to a GGML compatible file"
    )

    # Options, registered in their original order so --help output is unchanged.
    cli.add_argument(
        "--vocab-only", action="store_true", help="extract only the vocab"
    )
    cli.add_argument(
        "--outfile", type=Path, help="path to write to; default: based on input"
    )
    cli.add_argument(
        "--outtype",
        type=str,
        choices=["f32", "f16"],
        default="f16",
        help="output format - use f32 for float32, f16 for float16",
    )
    cli.add_argument(
        "--bigendian",
        action="store_true",
        help="model is executed on big endian machine",
    )

    # Positional directory arguments.
    positionals = (
        ("model", "directory containing model file"),
        ("mlp_predictors", "directory containing MLP predictors for model"),
    )
    for dest, text in positionals:
        cli.add_argument(dest, type=Path, help=text)

    return cli.parse_args()
def main() -> None:
    """Run the conversion described by the command-line arguments.

    Writes the GGUF file (only header and metadata with ``--vocab-only``),
    then replaces its first four bytes with the PowerInfer magic ``PWRI``.
    Exits with status 1 if either input directory does not exist.
    """
    args = parse_args()
    dir_model = args.model
    dir_mlp_pred = args.mlp_predictors
    if not dir_model.is_dir():
        print(f"Error: {args.model} is not a directory", file=sys.stderr)
        sys.exit(1)
    if not dir_mlp_pred.is_dir():
        print(f"Error: {args.mlp_predictors} is not a directory", file=sys.stderr)
        sys.exit(1)

    ftype_map = {
        "f32": gguf.GGMLQuantizationType.F32,
        "f16": gguf.GGMLQuantizationType.F16,
    }

    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_model / f"ggml-model-{args.outtype}.gguf"

    print(f"Loading model: {dir_model.name}")
    hparams = Model.load_hparams(dir_model)
    model_class = Model.from_model_architecture(hparams["architectures"][0])
    model_instance = model_class(
        dir_model, dir_mlp_pred, ftype_map[args.outtype], fname_out, args.bigendian
    )

    print("Set model parameters")
    params = PredictorParams.load(model_instance)
    model_instance.set_gguf_parameters(params)

    print("Set model tokenizer")
    model_instance.set_vocab()

    if args.vocab_only:
        print(f"Exporting model vocab to '{fname_out}'")
        model_instance.write_vocab()
    else:
        print(f"Exporting model to '{fname_out}'")
        model_instance.write()

    # Post-process: overwrite the first four bytes of the file (the GGUF magic)
    # with a PowerInfer-specific magic so it is distinguishable from a plain
    # GGUF file. Same bytes as struct.pack("<I", int.from_bytes(b"PWRI", "little")).
    powerinfer_magic = b"PWRI"
    with open(fname_out, "r+b") as fout:
        fout.write(powerinfer_magic)

    print(f"Model successfully exported to '{fname_out}'")


if __name__ == "__main__":
    main()