Various script cleanups/fixes + convert merges and special token handling #2842

Merged Aug 30, 2023 · 17 commits · Changes shown from 16 commits

168 changes: 75 additions & 93 deletions convert-falcon-hf-to-gguf.py
@@ -8,6 +8,7 @@
import json
import numpy as np
import torch
import argparse

from typing import Any, List
from pathlib import Path
@@ -32,11 +33,10 @@ def bytes_to_unicode():
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
return dict(zip(bs, (chr(n) for n in cs)))


def count_model_parts(dir_model: str) -> int:
def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
@@ -47,16 +47,21 @@ def count_model_parts(dir_model: str) -> int:
return num_parts


if len(sys.argv) < 3:
print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
print(" ftype == 0 -> float32")
print(" ftype == 1 -> float16")
sys.exit(1)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a Falcon model to a GGML compatible file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
return parser.parse_args()

args = parse_args()

# output in the same directory as the model
dir_model = sys.argv[1]
last_dir = os.path.basename(os.path.normpath(dir_model))
dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
@@ -65,25 +70,21 @@ def count_model_parts(dir_model: str) -> int:
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
ftype = int(sys.argv[2])
if ftype < 0 or ftype > 1:
print("Invalid ftype: " + str(ftype))

sys.exit(1)

fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+last_dir)
print("gguf: loading model "+dir_model.name)

with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)

if hparams["architectures"][0] != "RWForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])

sys.exit()
sys.exit(1)

# get number of model parts
num_parts = count_model_parts(dir_model)
@@ -113,77 +114,58 @@ def count_model_parts(dir_model: str) -> int:

print("gguf: get tokenizer metadata")

tokens: List[str] = []
tokens: List[bytearray] = []
scores: List[float] = []
toktypes: List[int] = []
merges: List[str] = []


if Path(dir_model + "/tokenizer.json").is_file():
# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

print("gguf: get gpt2 tokenizer merges")

with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
tokenizer_json = json.load(f)
merges = tokenizer_json["model"]["merges"]

gguf_writer.add_token_merges(merges)

print("gguf: get gpt2 tokenizer vocab")

vocab_size = len(tokenizer_json["model"]["vocab"])

# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
tokenizer_json_file = dir_model / 'tokenizer.json'
if not tokenizer_json_file.is_file():
print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
sys.exit(1)

for i in range(vocab_size):
if i in reverse_vocab:
try:
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
except KeyError:
text = bytearray()
for c in reverse_vocab[i]:
if ord(c) < 256: # single byte character
text.append(byte_decoder[ord(c)])
else: # multibyte special token character
text.extend(c.encode('utf-8'))
else:
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
pad_token = f"[PAD{i}]".encode("utf8")
text = bytearray(pad_token)
# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

tokens.append(text)
scores.append(0.0) # dymmy
toktypes.append(gguf.TokenType.NORMAL) # dummy
with open(tokenizer_json_file, "r", encoding="utf-8") as f:
tokenizer_json = json.load(f)

gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
print("gguf: get gpt2 tokenizer vocab")

print("gguf: get special token ids")
# Look for special tokens in config.json
vocab_size = len(tokenizer_json["model"]["vocab"])

if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
gguf_writer.add_bos_token_id(hparams["bos_token_id"])
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)

if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
gguf_writer.add_eos_token_id(hparams["eos_token_id"])
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
gguf_writer.add_unk_token_id(hparams["unk_token_id"])
for i in range(vocab_size):
if i in reverse_vocab:
try:
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
except KeyError:
text = bytearray()
for c in reverse_vocab[i]:
if ord(c) < 256: # single byte character
text.append(byte_decoder[ord(c)])
else: # multibyte special token character
text.extend(c.encode('utf-8'))
else:
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
pad_token = f"[PAD{i}]".encode("utf8")
text = bytearray(pad_token)

if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
gguf_writer.add_sep_token_id(hparams["sep_token_id"])
tokens.append(text)
scores.append(0.0) # dymmy
toktypes.append(gguf.TokenType.NORMAL) # dummy

if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
gguf_writer.add_pad_token_id(hparams["pad_token_id"])
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

@@ -199,15 +181,17 @@ def count_model_parts(dir_model: str) -> int:
print("gguf: get tensor metadata")

if num_parts == 0:
part_names = ("pytorch_model.bin",)
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)

for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
model_part = torch.load(dir_model / part_name, map_location="cpu")

for name in model_part.keys():
data = model_part[name]
@@ -238,11 +222,8 @@ def count_model_parts(dir_model: str) -> int:
data = data.squeeze().numpy()

# map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map:
name = tensor_map[name[:-7]] + ".weight"
elif name.endswith(".bias") and name[:-5] in tensor_map:
name = tensor_map[name[:-5]] + ".bias"
else:
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Can not map tensor '" + name + "'")
sys.exit()

@@ -261,19 +242,20 @@ def count_model_parts(dir_model: str) -> int:
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

gguf_writer.add_tensor(name, data)
gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()

print("gguf: model successfully exported to '" + fname_out + "'")
print(f"gguf: model successfully exported to '{fname_out}'")
print("")
173 changes: 70 additions & 103 deletions convert-gptneox-hf-to-gguf.py
@@ -8,6 +8,7 @@
import json
import numpy as np
import torch
import argparse

from typing import Any, List
from pathlib import Path
@@ -34,11 +35,10 @@ def bytes_to_unicode():
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
return dict(zip(bs, (chr(n) for n in cs)))


def count_model_parts(dir_model: str) -> int:
def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
@@ -49,16 +49,21 @@ def count_model_parts(dir_model: str) -> int:
return num_parts


if len(sys.argv) < 3:
print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
print(" ftype == 0 -> float32")
print(" ftype == 1 -> float16")
sys.exit(1)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a GPT-NeoX model to a GGML compatible file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
return parser.parse_args()

args = parse_args()

# output in the same directory as the model
dir_model = sys.argv[1]
last_dir = os.path.basename(os.path.normpath(dir_model))
dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
@@ -67,19 +72,15 @@ def count_model_parts(dir_model: str) -> int:
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
ftype = int(sys.argv[2])
if ftype < 0 or ftype > 1:
print("Invalid ftype: " + str(ftype))

sys.exit(1)

fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+last_dir)
print("gguf: loading model "+dir_model.name)

with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)

if hparams["architectures"][0] != "GPTNeoXForCausalLM":
@@ -97,7 +98,7 @@ def count_model_parts(dir_model: str) -> int:

block_count = hparams["num_hidden_layers"]

gguf_writer.add_name(last_dir)
gguf_writer.add_name(dir_model.name)
gguf_writer.add_context_length(hparams["max_position_embeddings"])
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_block_count(block_count)
@@ -111,86 +112,52 @@ def count_model_parts(dir_model: str) -> int:

print("gguf: get tokenizer metadata")

tokens: List[str] = []
merges: List[str] = []


if Path(dir_model + "/tokenizer.json").is_file():
# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

print("gguf: get gpt2 tokenizer merges")
tokens: List[bytearray] = []
Collaborator Author:
These changes shut up the type warnings but I'm not 100% sure they're correct. The alternative would be to leave it List[str] and then convert the token text to a string. I assume it's already UTF-8 bytes so there probably isn't a functional difference.

Collaborator:
The types are correct this way, because when we get to tokens.append(text), 'text' is explicitly a bytearray.

Collaborator Author:
> The types are correct this way, because when we get to tokens.append(text), 'text' is explicitly a bytearray.

Right. I meant my change fixes the type warning, but the part I wasn't sure about was whether text actually is supposed to be a bytearray there and what is supposed to get submitted to gguf to write out the tokens. The type annotation for the add_token_list method is also just List and doesn't specify what the element is supposed to be.

Contributor:
The decision was made to accept several types as input to the token list, depending on the type of the tokenizer output (spm vs bpe):
https://github.com/ggerganov/llama.cpp/blob/dd0dc366dab10e8df28d3924e7f313b5c695e908/gguf-py/gguf/gguf.py#L373-L375

Collaborator Author (@KerfuffleV2, Aug 28, 2023):
> The decision was made to accept several types as input to the token list

So we'd want

def add_token_list(self, tokens: Union[List[str], List[bytes], List[bytearray]]):

Correct? Or would you want to allow the items to be non-homogeneous, like List[Union[str, bytes, bytearray]]?

Contributor:
What are the differences between the Python types str, bytes and bytearray? If they all resolve to the same thing when written to disk, then any of them could be used.

Contributor:
Make it as simple as it can be, as long as there is no difference in how the tokens are written to disk.

Collaborator:
str is unicode text which will be encoded as utf-8 before being written to disk. bytes and bytearray are binary data. Those two are subclasses of ByteString, but we can't use that because it also allows memoryview. The least repetitive way to write this would be to use a TypeVar:

StrOrBytes = TypeVar('StrOrBytes', str, bytes, bytearray)
# ...
    def add_token_list(self, tokens: Sequence[StrOrBytes]):
        # ...
    def add_token_merges(self, merges: Sequence[StrOrBytes]):
        # ...

Collaborator Author:
Should merges actually support all of those?
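
To illustrate the distinction being discussed, here is a minimal sketch (the function names are placeholders, not the actual gguf-py API) of how a constrained TypeVar differs from a Union element type: the TypeVar lets a type checker require that each call passes a homogeneous sequence, while the Union form also admits mixed lists.

```python
from typing import Sequence, TypeVar, Union

# Constrained TypeVar: all elements of a single call must share one of the three types.
StrOrBytes = TypeVar('StrOrBytes', str, bytes, bytearray)

def add_token_list_typevar(tokens: Sequence[StrOrBytes]) -> None:
    for tok in tokens:
        print(type(tok).__name__, tok)

# Union element type: a single sequence may freely mix str, bytes and bytearray.
def add_token_list_union(tokens: Sequence[Union[str, bytes, bytearray]]) -> None:
    for tok in tokens:
        print(type(tok).__name__, tok)

add_token_list_typevar(["hello", "world"])                   # ok: all str
add_token_list_typevar([bytearray(b"a"), bytearray(b"b")])   # ok: all bytearray
add_token_list_union(["hello", b"raw", bytearray(b"pad")])   # ok: mixed element types
# add_token_list_typevar(["hello", b"raw"])  # a checker such as mypy rejects this: not homogeneous
```

At runtime neither annotation enforces anything; the difference only shows up when the file is checked with mypy or a similar tool.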


with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
tokenizer_json = json.load(f)
merges = tokenizer_json["model"]["merges"]

gguf_writer.add_token_merges(merges)

print("gguf: get gpt2 tokenizer vocab")

vocab_size = len(tokenizer_json["model"]["vocab"])

# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

for i in range(vocab_size):
if i in reverse_vocab:
try:
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
except KeyError:
text = bytearray()
for c in reverse_vocab[i]:
if ord(c) < 256: # single byte character
text.append(byte_decoder[ord(c)])
else: # multibyte special token character
text.extend(c.encode('utf-8'))
else:
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
pad_token = f"[PAD{i}]".encode("utf8")
text = bytearray(pad_token)

tokens.append(text)
tokenizer_json_file = dir_model / 'tokenizer.json'
if not tokenizer_json_file.is_file():
print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
sys.exit(1)

gguf_writer.add_token_list(tokens)
# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file():
print("gguf: get special token ids")
with open(tokenizer_json_file, "r", encoding="utf-8") as f:
tokenizer_json = json.load(f)

with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
print("gguf: get gpt2 tokenizer vocab")

# find special token ids
vocab_size = len(tokenizer_json["model"]["vocab"])

if "bos_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["bos_token"]:
gguf_writer.add_bos_token_id(key["id"])
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)

if "eos_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["eos_token"]:
gguf_writer.add_eos_token_id(key["id"])
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

if "unk_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["unk_token"]:
gguf_writer.add_unk_token_id(key["id"])
for i in range(vocab_size):
if i in reverse_vocab:
try:
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
except KeyError:
text = bytearray()
for c in reverse_vocab[i]:
if ord(c) < 256: # single byte character
text.append(byte_decoder[ord(c)])
else: # multibyte special token character
text.extend(c.encode('utf-8'))
else:
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
pad_token = f"[PAD{i}]".encode("utf8")
text = bytearray(pad_token)

if "sep_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["sep_token"]:
gguf_writer.add_sep_token_id(key["id"])
tokens.append(text)

if "pad_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["pad_token"]:
gguf_writer.add_pad_token_id(key["id"])
gguf_writer.add_token_list(tokens)

special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

@@ -200,13 +167,15 @@ def count_model_parts(dir_model: str) -> int:
print("gguf: get tensor metadata")

if num_parts == 0:
part_names = ("pytorch_model.bin",)
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)

for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

@@ -226,11 +195,8 @@ def count_model_parts(dir_model: str) -> int:
data = data.squeeze().numpy()

# map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map:
name = tensor_map[name[:-7]] + ".weight"
elif name.endswith(".bias") and name[:-5] in tensor_map:
name = tensor_map[name[:-5]] + ".bias"
else:
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Can not map tensor '" + name + "'")
sys.exit()

@@ -249,19 +215,20 @@ def count_model_parts(dir_model: str) -> int:
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

gguf_writer.add_tensor(name, data)
gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()

print("gguf: model successfully exported to '" + fname_out + "'")
print(f"gguf: model successfully exported to '{fname_out}'")
print("")
200 changes: 75 additions & 125 deletions convert-llama-7b-pth-to-gguf.py
@@ -10,8 +10,9 @@
import json
import numpy as np
import torch
import argparse

from typing import Any, List
from typing import Any, List, TypeAlias
from pathlib import Path
from sentencepiece import SentencePieceProcessor

@@ -20,7 +21,7 @@
NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'


def count_model_parts(dir_model: str) -> int:
def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("consolidated."):
@@ -31,18 +32,21 @@ def count_model_parts(dir_model: str) -> int:
return num_parts


if len(sys.argv) < 3:
print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
print(" ftype == 0 -> float32")
print(" ftype == 1 -> float16")

sys.exit(1)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a PyTorch 7B LLaMA model to a GGML compatible file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
return parser.parse_args()

args = parse_args()

# output in the same directory as the model
dir_model = sys.argv[1]
last_dir = os.path.basename(os.path.normpath(dir_model))

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
@@ -51,19 +55,15 @@ def count_model_parts(dir_model: str) -> int:
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
ftype = int(sys.argv[2])
if ftype < 0 or ftype > 1:
print("Invalid ftype: " + str(ftype))

sys.exit(1)

fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+last_dir)
print("gguf: loading model "+dir_model.name)

with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)

if hparams["architectures"][0] != "LlamaForCausalLM":
@@ -107,7 +107,7 @@ def count_model_parts(dir_model: str) -> int:
sys.exit()


gguf_writer.add_name(last_dir)
gguf_writer.add_name(dir_model.name)
gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_tensor_data_layout("Meta AI original pth")
gguf_writer.add_context_length(ctx_length)
@@ -133,109 +133,60 @@ def count_model_parts(dir_model: str) -> int:
scores: List[float] = []
toktypes: List[int] = []

if Path(dir_model + "/tokenizer.model").is_file():
# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab and scores")

tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")

for i in range(tokenizer.vocab_size()):
text: bytes
score: float

piece = tokenizer.id_to_piece(i)
text = piece.encode("utf-8")
score = tokenizer.get_score(i)

toktype = 1 # defualt to normal token type
if tokenizer.is_unknown(i):
toktype = 2
if tokenizer.is_control(i):
toktype = 3

# toktype = 4 is user-defined = tokens from added_tokens.json

if tokenizer.is_unused(i):
toktype = 5
if tokenizer.is_byte(i):
toktype = 6

tokens.append(text)
scores.append(score)
toktypes.append(toktype)

if Path(dir_model + "/added_tokens.json").is_file():
with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f:
addtokens_json = json.load(f)

print("gguf: get added tokens")

for key in addtokens_json:
tokens.append( key.encode("utf-8") )
scores.append(-1000.0)
toktypes.append(4) # user-defined token type

gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)


print("gguf: get special token ids")

if Path(dir_model + "/tokenizer.json").is_file():
# Look for special tokens in tokenizer.json if it exists

with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
tokenizer = json.load(f)
tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():
print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
sys.exit(1)

if "added_tokens" in tokenizer and Path(dir_model + "/tokenizer_config.json").is_file():
# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab and scores")

with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
tokenizer = SentencePieceProcessor(str(tokenizer_model_file))

if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["bos_token"]["content"]:
gguf_writer.add_bos_token_id(key["id"])
for i in range(tokenizer.vocab_size()):
text: bytes
score: float

if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["eos_token"]["content"]:
gguf_writer.add_eos_token_id(key["id"])
piece = tokenizer.id_to_piece(i)
text = piece.encode("utf-8")
score = tokenizer.get_score(i)

if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["unk_token"]["content"]:
gguf_writer.add_unk_token_id(key["id"])
toktype = 1 # defualt to normal token type
if tokenizer.is_unknown(i):
toktype = 2
if tokenizer.is_control(i):
toktype = 3

if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["sep_token"]["content"]:
gguf_writer.add_sep_token_id(key["id"])
# toktype = 4 is user-defined = tokens from added_tokens.json

if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["pad_token"]["content"]:
gguf_writer.add_pad_token_id(key["id"])
else:
# If no tokenizer.json: Look for special tokens in config.json
if tokenizer.is_unused(i):
toktype = 5
if tokenizer.is_byte(i):
toktype = 6

if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
gguf_writer.add_bos_token_id(hparams["bos_token_id"])
tokens.append(text)
scores.append(score)
toktypes.append(toktype)

if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
gguf_writer.add_eos_token_id(hparams["eos_token_id"])
added_tokens_file = dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
addtokens_json = json.load(f)

if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
gguf_writer.add_unk_token_id(hparams["unk_token_id"])
print("gguf: get added tokens")

if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
gguf_writer.add_sep_token_id(hparams["sep_token_id"])
for key in addtokens_json:
tokens.append( key.encode("utf-8") )
scores.append(-1000.0)
toktypes.append(4) # user-defined token type

if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
gguf_writer.add_pad_token_id(hparams["pad_token_id"])
gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

@@ -247,6 +198,8 @@ def count_model_parts(dir_model: str) -> int:
part_names = (f"consolidated.{n:02}.pth" for n in range(0, num_parts))

for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

@@ -266,11 +219,8 @@ def count_model_parts(dir_model: str) -> int:
data = data.squeeze().numpy()

# map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map:
name = tensor_map[name[:-7]] + ".weight"
elif name.endswith(".bias") and name[:-5] in tensor_map:
name = tensor_map[name[:-5]] + ".bias"
else:
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Can not map tensor '" + name + "'")
sys.exit()

@@ -289,20 +239,20 @@ def count_model_parts(dir_model: str) -> int:
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

gguf_writer.add_tensor(name, data)
gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()


print("gguf: model successfully exported to '" + fname_out + "'")
print(f"gguf: model successfully exported to '{fname_out}'")
print("")
28 changes: 14 additions & 14 deletions convert-llama-ggmlv3-to-gguf.py
@@ -75,7 +75,7 @@ def __init__(self):
self.dims = ()
self.dtype = None
self.start_offset = 0
self.len_bytes = 0
self.len_bytes = np.int64(0)

def load(self, data, offset):
orig_offset = offset
@@ -134,13 +134,14 @@ def load(self, data, offset):
return offset

class GGMLToGGUF:
def __init__(self, ggml_model, data, cfg, params_override = None, vocab_override = None):
def __init__(self, ggml_model, data, cfg, params_override = None, vocab_override = None, special_vocab = None):
hp = ggml_model.hyperparameters
self.model = ggml_model
self.data = data
self.cfg = cfg
self.params_override = params_override
self.vocab_override = vocab_override
self.special_vocab = special_vocab
if params_override is not None:
n_kv_head = params_override.n_head_kv
else:
@@ -162,6 +163,8 @@ def save(self):
gguf_writer = gguf.GGUFWriter(self.cfg.output, gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], use_temp_file = False)
self.add_params(gguf_writer)
self.add_vocab(gguf_writer)
if self.special_vocab is not None:
self.special_vocab.add_to_gguf(gguf_writer)
self.add_tensors(gguf_writer)
print(" gguf: write header")
gguf_writer.write_header_to_file()
@@ -259,20 +262,13 @@ def add_vocab(self, gguf_writer):
gguf_writer.add_eos_token_id(2)

def add_tensors(self, gguf_writer):
nm = self.name_map
tensor_map = self.name_map
data = self.data
print(f'* Adding {len(self.model.tensors)} tensor(s)')
for tensor in self.model.tensors:
name = str(tensor.name, 'UTF-8')
if name.endswith('.weight'):
name = name[:-7]
suffix = '.weight'
elif name.endswith('.bias'):
name = name[:-5]
suffix = '.bias'
mapped_name = nm.get(name)
mapped_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
assert mapped_name is not None, f'Bad name {name}'
mapped_name += suffix
tempdims = list(tensor.dims[:])
if len(tempdims) > 1:
temp = tempdims[1]
@@ -302,8 +298,10 @@ def handle_metadata(cfg, hp):
else:
raise ValueError('Unable to load metadata')
vocab = convert.load_vocab(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, cfg.vocabtype)
# FIXME: Respect cfg.vocab_dir?
svocab = gguf.SpecialVocab(cfg.model_metadata_dir)
convert.check_vocab_size(params, vocab)
return (params, vocab)
return (params, vocab, svocab)

def handle_args():
parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF')
@@ -330,14 +328,16 @@ def main():
print(f'* GGML model hyperparameters: {model.hyperparameters}')
vocab_override = None
params_override = None
special_vocab = None
if cfg.model_metadata_dir is not None:
(params_override, vocab_override) = handle_metadata(cfg, model.hyperparameters)
(params_override, vocab_override, special_vocab) = handle_metadata(cfg, model.hyperparameters)
print('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.')
print(f'* Overriding params: {params_override}')
print(f'* Overriding vocab: {vocab_override}')
print(f'* Special vocab: {special_vocab}')
else:
print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n')
converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override)
converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override, special_vocab = special_vocab)
converter.save()
print(f'* Successful completion. Output saved to: {cfg.output}')

201 changes: 75 additions & 126 deletions convert-llama-hf-to-gguf.py
@@ -8,8 +8,9 @@
import json
import numpy as np
import torch
import argparse

from typing import Any, List, Optional
from typing import Any, List, Optional, TypeAlias
from pathlib import Path
from sentencepiece import SentencePieceProcessor

@@ -43,40 +44,38 @@ def count_model_parts(dir_model: str) -> int:
return num_parts


if len(sys.argv) < 3:
print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
print(" ftype == 0 -> float32")
print(" ftype == 1 -> float16")

sys.exit(1)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a HuggingFace LLaMA model to a GGML compatible file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
return parser.parse_args()

args = parse_args()

# output in the same directory as the model
dir_model = sys.argv[1]
last_dir = os.path.basename(os.path.normpath(dir_model))

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16


# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
ftype = int(sys.argv[2])
if ftype < 0 or ftype > 1:
print("Invalid ftype: " + str(ftype))

sys.exit(1)

fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+last_dir)
print("gguf: loading model "+dir_model.name)

with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)

if hparams["architectures"][0] != "LlamaForCausalLM":
@@ -115,7 +114,7 @@ def count_model_parts(dir_model: str) -> int:
sys.exit()


gguf_writer.add_name(last_dir)
gguf_writer.add_name(dir_model.name)
gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_tensor_data_layout("Meta AI original pth")
gguf_writer.add_context_length(ctx_length)
@@ -141,110 +140,61 @@ def count_model_parts(dir_model: str) -> int:
scores: List[float] = []
toktypes: List[int] = []

if Path(dir_model + "/tokenizer.model").is_file():
# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab, scores and token types")

tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model")

for i in range(tokenizer.vocab_size()):
text: bytes
score: float

piece = tokenizer.id_to_piece(i)
text = piece.encode("utf-8")
score = tokenizer.get_score(i)

toktype = 1 # defualt to normal token type
if tokenizer.is_unknown(i):
toktype = 2
if tokenizer.is_control(i):
toktype = 3

# toktype = 4 is user-defined = tokens from added_tokens.json

if tokenizer.is_unused(i):
toktype = 5
if tokenizer.is_byte(i):
toktype = 6

tokens.append(text)
scores.append(score)
toktypes.append(toktype)

if Path(dir_model + "/added_tokens.json").is_file():
with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f:
addtokens_json = json.load(f)

print("gguf: get added tokens")

for key in addtokens_json:
tokens.append( key.encode("utf-8") )
scores.append(-1000.0)
toktypes.append(4) # user-defined token type


gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)


print("gguf: get special token ids")

if Path(dir_model + "/tokenizer.json").is_file():
# Look for special tokens in tokenizer.json if it exists
tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():
print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
sys.exit(1)

with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
tokenizer = json.load(f)
# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab, scores and token types")

if "added_tokens" in tokenizer and Path(dir_model + "/tokenizer_config.json").is_file():
tokenizer = SentencePieceProcessor(str(tokenizer_model_file))

with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = json.load(f)
for i in range(tokenizer.vocab_size()):
text: bytes
score: float

if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["bos_token"]["content"]:
gguf_writer.add_bos_token_id(key["id"])
piece = tokenizer.id_to_piece(i)
text = piece.encode("utf-8")
score = tokenizer.get_score(i)

if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["eos_token"]["content"]:
gguf_writer.add_eos_token_id(key["id"])
toktype = 1 # defualt to normal token type
if tokenizer.is_unknown(i):
toktype = 2
if tokenizer.is_control(i):
toktype = 3

if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["unk_token"]["content"]:
gguf_writer.add_unk_token_id(key["id"])
# toktype = 4 is user-defined = tokens from added_tokens.json

if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["sep_token"]["content"]:
gguf_writer.add_sep_token_id(key["id"])
if tokenizer.is_unused(i):
toktype = 5
if tokenizer.is_byte(i):
toktype = 6

if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] != None:
for key in tokenizer["added_tokens"]:
if key["content"] == tokenizer_config["pad_token"]["content"]:
gguf_writer.add_pad_token_id(key["id"])
else:
# If no tokenizer.json: Look for special tokens in config.json
tokens.append(text)
scores.append(score)
toktypes.append(toktype)

if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
gguf_writer.add_bos_token_id(hparams["bos_token_id"])
added_tokens_file = dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
addtokens_json = json.load(f)

if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
gguf_writer.add_eos_token_id(hparams["eos_token_id"])
print("gguf: get added tokens")

if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
gguf_writer.add_unk_token_id(hparams["unk_token_id"])
for key in addtokens_json:
tokens.append( key.encode("utf-8") )
scores.append(-1000.0)
toktypes.append(4) # user-defined token type

if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
gguf_writer.add_sep_token_id(hparams["sep_token_id"])

if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
gguf_writer.add_pad_token_id(hparams["pad_token_id"])
gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

@@ -254,13 +204,15 @@ def count_model_parts(dir_model: str) -> int:
print("gguf: get tensor metadata")

if num_parts == 0:
part_names = ("pytorch_model.bin",)
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)

for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

@@ -286,11 +238,8 @@ def count_model_parts(dir_model: str) -> int:
data = reverse_hf_permute(data, head_count, head_count_kv)

# map tensor names
if name.endswith(".weight") and name[:-7] in tensor_map:
name = tensor_map[name[:-7]] + ".weight"
elif name.endswith(".bias") and name[:-5] in tensor_map:
name = tensor_map[name[:-5]] + ".bias"
else:
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Can not map tensor '" + name + "'")
sys.exit()

@@ -309,20 +258,20 @@ def count_model_parts(dir_model: str) -> int:
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

gguf_writer.add_tensor(name, data)
gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()


print("gguf: model successfully exported to '" + fname_out + "'")
print(f"gguf: model successfully exported to '{fname_out}'")
print("")
6 changes: 3 additions & 3 deletions convert-lora-to-ggml.py
@@ -4,7 +4,7 @@
import re
import struct
import sys
from typing import Any, Dict, Sequence, TextIO
from typing import Any, Dict, Sequence, BinaryIO

import numpy as np
import torch
@@ -46,7 +46,7 @@ def translate_tensor_name(t: str) -> str:
sys.exit(1)


def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:
def write_file_header(fout: BinaryIO, params: Dict[str, Any]) -> None:
fout.write(b"ggla"[::-1]) # magic (ggml lora)
fout.write(struct.pack("i", 1)) # file version
fout.write(struct.pack("i", params["r"]))
@@ -60,7 +60,7 @@ def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:


def write_tensor_header(
self, name: str, shape: Sequence[int], data_type: np.dtype
self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
) -> None:
sname = name.encode("utf-8")
fout.write(
148 changes: 85 additions & 63 deletions convert.py

Large diffs are not rendered by default.

570 changes: 339 additions & 231 deletions gguf-py/gguf/gguf.py

Large diffs are not rendered by default.

Empty file added gguf-py/gguf/py.typed
1 change: 1 addition & 0 deletions gguf-py/pyproject.toml
@@ -5,6 +5,7 @@ description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [
{include = "gguf"},
{include = "gguf/py.typed"},
]
readme = "README.md"
homepage = "https://ggml.ai"