Skip to content

Commit 0099b74

Browse files
cebtenzzre authored and yusiwen committed
gguf : general usability improvements (ggml-org#3409)
1 parent cd8779f commit 0099b74

File tree

5 files changed

+120
-101
lines changed

5 files changed

+120
-101
lines changed

convert.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@
4141

4242
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
4343

44-
ARCH=gguf.MODEL_ARCH.LLAMA
45-
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
44+
ARCH = gguf.MODEL_ARCH.LLAMA
4645

4746
DEFAULT_CONCURRENCY = 8
4847
#
@@ -953,7 +952,7 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM
953952
of.close()
954953

955954
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
956-
wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
955+
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type
957956

958957
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
959958
return GGMLFileType.AllF32

examples/finetune/convert-finetune-checkpoint-to-gguf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def save_gguf(self, gguf_writer):
313313
gguf_writer.add_feed_forward_length(self.get_n_ff())
314314

315315
def tensor_name(key, bid=None, suffix=".weight"):
316-
return gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][key].format(bid=bid) + suffix
316+
return gguf.TENSOR_NAMES[key].format(bid=bid) + suffix
317317

318318
class Layer:
319319
def __init__(self, params, lora_params, bid):

examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def save_gguf(self, gguf_writer):
364364
gguf_writer.add_feed_forward_length(self.get_n_ff())
365365

366366
def tensor_name(key, bid=None):
367-
return gguf.MODEL_TENSOR_NAMES[gguf.MODEL_ARCH.LLAMA][key].format(bid=bid) + ".weight"
367+
return gguf.TENSOR_NAMES[key].format(bid=bid) + ".weight"
368368

369369
class Layer:
370370
def __init__(self, params, bid):

gguf-py/gguf/gguf.py

+115-95
Original file line numberDiff line numberDiff line change
@@ -118,76 +118,97 @@ class MODEL_TENSOR(IntEnum):
118118
MODEL_ARCH.STARCODER: "starcoder",
119119
}
120120

121-
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
122-
MODEL_ARCH.LLAMA: {
123-
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
124-
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
125-
MODEL_TENSOR.OUTPUT: "output",
126-
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
127-
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
128-
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
129-
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
130-
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
131-
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
132-
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
133-
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
134-
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
135-
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
136-
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
137-
},
138-
MODEL_ARCH.GPTNEOX: {
139-
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
140-
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
141-
MODEL_TENSOR.OUTPUT: "output",
142-
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
143-
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
144-
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
145-
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
146-
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
147-
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
148-
},
149-
MODEL_ARCH.FALCON: {
150-
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
151-
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
152-
MODEL_TENSOR.OUTPUT: "output",
153-
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
154-
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
155-
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
156-
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
157-
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
158-
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
159-
},
160-
MODEL_ARCH.BAICHUAN: {
161-
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
162-
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
163-
MODEL_TENSOR.OUTPUT: "output",
164-
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
165-
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
166-
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
167-
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
168-
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
169-
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
170-
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
171-
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
172-
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
173-
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
174-
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
175-
},
176-
MODEL_ARCH.STARCODER: {
177-
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
178-
MODEL_TENSOR.POS_EMBD: "position_embd",
179-
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
180-
MODEL_TENSOR.OUTPUT: "output",
181-
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
182-
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
183-
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
184-
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
185-
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
186-
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
187-
},
188-
MODEL_ARCH.GPT2: {
121+
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
122+
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
123+
MODEL_TENSOR.POS_EMBD: "position_embd",
124+
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
125+
MODEL_TENSOR.OUTPUT: "output",
126+
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
127+
128+
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
129+
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
130+
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
131+
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
132+
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
133+
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
134+
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
135+
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
136+
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
137+
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
138+
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
139+
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
140+
}
141+
142+
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
143+
MODEL_ARCH.LLAMA: [
144+
MODEL_TENSOR.TOKEN_EMBD,
145+
MODEL_TENSOR.OUTPUT_NORM,
146+
MODEL_TENSOR.OUTPUT,
147+
MODEL_TENSOR.ROPE_FREQS,
148+
MODEL_TENSOR.ATTN_NORM,
149+
MODEL_TENSOR.ATTN_Q,
150+
MODEL_TENSOR.ATTN_K,
151+
MODEL_TENSOR.ATTN_V,
152+
MODEL_TENSOR.ATTN_OUT,
153+
MODEL_TENSOR.ATTN_ROT_EMBD,
154+
MODEL_TENSOR.FFN_NORM,
155+
MODEL_TENSOR.FFN_GATE,
156+
MODEL_TENSOR.FFN_DOWN,
157+
MODEL_TENSOR.FFN_UP,
158+
],
159+
MODEL_ARCH.GPTNEOX: [
160+
MODEL_TENSOR.TOKEN_EMBD,
161+
MODEL_TENSOR.OUTPUT_NORM,
162+
MODEL_TENSOR.OUTPUT,
163+
MODEL_TENSOR.ATTN_NORM,
164+
MODEL_TENSOR.ATTN_QKV,
165+
MODEL_TENSOR.ATTN_OUT,
166+
MODEL_TENSOR.FFN_NORM,
167+
MODEL_TENSOR.FFN_DOWN,
168+
MODEL_TENSOR.FFN_UP,
169+
],
170+
MODEL_ARCH.FALCON: [
171+
MODEL_TENSOR.TOKEN_EMBD,
172+
MODEL_TENSOR.OUTPUT_NORM,
173+
MODEL_TENSOR.OUTPUT,
174+
MODEL_TENSOR.ATTN_NORM,
175+
MODEL_TENSOR.ATTN_NORM_2,
176+
MODEL_TENSOR.ATTN_QKV,
177+
MODEL_TENSOR.ATTN_OUT,
178+
MODEL_TENSOR.FFN_DOWN,
179+
MODEL_TENSOR.FFN_UP,
180+
],
181+
MODEL_ARCH.BAICHUAN: [
182+
MODEL_TENSOR.TOKEN_EMBD,
183+
MODEL_TENSOR.OUTPUT_NORM,
184+
MODEL_TENSOR.OUTPUT,
185+
MODEL_TENSOR.ROPE_FREQS,
186+
MODEL_TENSOR.ATTN_NORM,
187+
MODEL_TENSOR.ATTN_Q,
188+
MODEL_TENSOR.ATTN_K,
189+
MODEL_TENSOR.ATTN_V,
190+
MODEL_TENSOR.ATTN_OUT,
191+
MODEL_TENSOR.ATTN_ROT_EMBD,
192+
MODEL_TENSOR.FFN_NORM,
193+
MODEL_TENSOR.FFN_GATE,
194+
MODEL_TENSOR.FFN_DOWN,
195+
MODEL_TENSOR.FFN_UP,
196+
],
197+
MODEL_ARCH.STARCODER: [
198+
MODEL_TENSOR.TOKEN_EMBD,
199+
MODEL_TENSOR.POS_EMBD,
200+
MODEL_TENSOR.OUTPUT_NORM,
201+
MODEL_TENSOR.OUTPUT,
202+
MODEL_TENSOR.ATTN_NORM,
203+
MODEL_TENSOR.ATTN_QKV,
204+
MODEL_TENSOR.ATTN_OUT,
205+
MODEL_TENSOR.FFN_NORM,
206+
MODEL_TENSOR.FFN_DOWN,
207+
MODEL_TENSOR.FFN_UP,
208+
],
209+
MODEL_ARCH.GPT2: [
189210
# TODO
190-
},
211+
],
191212
# TODO
192213
}
193214

@@ -338,28 +359,24 @@ class TensorNameMap:
338359

339360
mapping: dict[str, tuple[MODEL_TENSOR, str]]
340361

341-
tensor_names: dict[MODEL_TENSOR, str]
342-
343362
def __init__(self, arch: MODEL_ARCH, n_blocks: int):
344-
mapping = self.mapping = {}
345-
tensor_names = self.tensor_names = MODEL_TENSOR_NAMES[arch]
363+
self.mapping = {}
346364
for tensor, keys in self.mappings_cfg.items():
347-
tensor_name = tensor_names.get(tensor)
348-
if tensor_name is None:
365+
if tensor not in MODEL_TENSORS[arch]:
349366
continue
350-
mapping[tensor_name] = (tensor, tensor_name)
367+
tensor_name = TENSOR_NAMES[tensor]
368+
self.mapping[tensor_name] = (tensor, tensor_name)
351369
for key in keys:
352-
mapping[key] = (tensor, tensor_name)
370+
self.mapping[key] = (tensor, tensor_name)
353371
for bid in range(n_blocks):
354372
for tensor, keys in self.block_mappings_cfg.items():
355-
tensor_name = tensor_names.get(tensor)
356-
if tensor_name is None:
373+
if tensor not in MODEL_TENSORS[arch]:
357374
continue
358-
tensor_name = tensor_name.format(bid = bid)
359-
mapping[tensor_name] = (tensor, tensor_name)
375+
tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
376+
self.mapping[tensor_name] = (tensor, tensor_name)
360377
for key in keys:
361378
key = key.format(bid = bid)
362-
mapping[key] = (tensor, tensor_name)
379+
self.mapping[key] = (tensor, tensor_name)
363380

364381
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
365382
result = self.mapping.get(key)
@@ -800,22 +817,25 @@ class SpecialVocab:
800817
special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
801818
special_token_ids: dict[str, int] = {}
802819

803-
def __init__(self, path: Path, load_merges: bool = False, special_token_types: tuple[str, ...] | None = None):
820+
def __init__(
821+
self, path: str | os.PathLike[str], load_merges: bool = False,
822+
special_token_types: tuple[str, ...] | None = None,
823+
):
804824
self.special_token_ids = {}
805825
self.load_merges = load_merges
806826
if special_token_types is not None:
807827
self.special_token_types = special_token_types
808-
self.load(path)
828+
self._load(Path(path))
809829

810-
def load(self, path: Path):
811-
if not self.try_load_from_tokenizer_json(path):
812-
self.try_load_from_config_json(path)
830+
def _load(self, path: Path) -> None:
831+
if not self._try_load_from_tokenizer_json(path):
832+
self._try_load_from_config_json(path)
813833

814-
def try_load_from_tokenizer_json(self, path: Path) -> bool:
834+
def _try_load_from_tokenizer_json(self, path: Path) -> bool:
815835
tokenizer_file = path / 'tokenizer.json'
816836
if not tokenizer_file.is_file():
817837
return False
818-
with open(tokenizer_file, 'r', encoding = 'utf-8') as f:
838+
with open(tokenizer_file, encoding = 'utf-8') as f:
819839
tokenizer = json.load(f)
820840
if self.load_merges:
821841
merges = tokenizer.get('model', {}).get('merges')
@@ -825,7 +845,7 @@ def try_load_from_tokenizer_json(self, path: Path) -> bool:
825845
added_tokens = tokenizer.get('added_tokens')
826846
if added_tokens is None or not tokenizer_config_file.is_file():
827847
return True
828-
with open(tokenizer_config_file, 'r', encoding = 'utf-8') as f:
848+
with open(tokenizer_config_file, encoding = 'utf-8') as f:
829849
tokenizer_config = json.load(f)
830850
for typ in self.special_token_types:
831851
entry = tokenizer_config.get(f'{typ}_token')
@@ -844,19 +864,19 @@ def try_load_from_tokenizer_json(self, path: Path) -> bool:
844864
break
845865
return True
846866

847-
def try_load_from_config_json(self, path: Path) -> bool:
867+
def _try_load_from_config_json(self, path: Path) -> bool:
848868
config_file = path / 'config.json'
849869
if not config_file.is_file():
850870
return False
851-
with open(config_file, 'r', encoding = 'utf-8') as f:
871+
with open(config_file, encoding = 'utf-8') as f:
852872
config = json.load(f)
853873
for typ in self.special_token_types:
854874
maybe_token_id = config.get(f'{typ}_token_id')
855875
if isinstance(maybe_token_id, int) and maybe_token_id >= 0:
856876
self.special_token_ids[typ] = maybe_token_id
857877
return True
858878

859-
def add_to_gguf(self, gw: GGUFWriter):
879+
def add_to_gguf(self, gw: GGUFWriter) -> None:
860880
if len(self.merges) > 0:
861881
print(f'gguf: Adding {len(self.merges)} merge(s).')
862882
gw.add_token_merges(self.merges)
@@ -868,8 +888,8 @@ def add_to_gguf(self, gw: GGUFWriter):
868888
print(f'gguf: Setting special token type {typ} to {tokid}')
869889
handler(tokid)
870890

871-
def __repr__(self):
872-
return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids if self.special_token_ids else "unset"}>'
891+
def __repr__(self) -> str:
892+
return f'<SpecialVocab with {len(self.merges)} merges and special tokens {self.special_token_ids or "unset"}>'
873893

874894

875895
# Example usage:

gguf-py/pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "gguf"
3-
version = "0.3.3"
3+
version = "0.4.0"
44
description = "Write ML models in GGUF for GGML"
55
authors = ["GGML <ggml@ggml.ai>"]
66
packages = [

0 commit comments

Comments (0)