Fix ChatGLMModel for glm-4-9b: "cannot find tokenizer merges in model file" #13058

Open · wants to merge 5 commits into base: master
convert_hf_to_gguf.py (69 changes: 62 additions & 7 deletions)
@@ -5022,18 +5022,73 @@ def set_vocab(self):

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"])
vocab_size = hparams.get("padded_vocab_size", hparams.get("vocab_size"))
assert max(tokenizer.get_vocab().values()) < vocab_size

tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
tokpre = self.get_vocab_base_pre(tokenizer)

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()}
added_vocab = tokenizer.get_added_vocab()

added_tokens_decoder = tokenizer.added_tokens_decoder

for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues, we make sure to normalize non-normalized tokens here.
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")

if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)

else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)

self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
try:
tokenizer_file = self.dir_model / 'tokenizer.json'
if not tokenizer_file.is_file():
raise ValueError("tokenizer.json not found")

# for https://huggingface.co/THUDM/glm-4-9b
special_vocab = gguf.SpecialVocab(
self.dir_model,
load_merges=True,
n_vocab=vocab_size
)

self.gguf_writer.add_tokenizer_model("gpt2")

except Exception as e:
logger.warning(f'Failed to load special tokens: {e}')
# for https://huggingface.co/THUDM/glm-4-9b-hf
special_vocab = gguf.SpecialVocab(
self.dir_model,
load_merges=False,
n_vocab=vocab_size
)
self.gguf_writer.add_tokenizer_model("llama")

# only add special tokens when they were not already loaded from config.json
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])

# TODO: llama.cpp maps each special-token role (e.g. eos) to exactly one token id,
# but a transformers tokenizer can associate a role such as eos_token_id with several tokens.
# Because only a one-to-one mapping is supported, the model can fail to terminate properly.
# A temporary workaround is described in https://github.com/ggml-org/llama.cpp/issues/9606
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
# this one is usually not in config.json anyway
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
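The hunk above resolves the reported error by probing for tokenizer.json and writing a gpt2-style BPE vocab only when merges can actually be loaded, otherwise falling back to the llama tokenizer model. Below is a minimal standalone sketch of that selection logic; the function name and the example path are assumptions for illustration and are not part of the converter.

```python
# Minimal sketch (not part of the diff) of the tokenizer-model fallback added above.
from pathlib import Path
import logging

logger = logging.getLogger(__name__)

def select_tokenizer_model(dir_model: Path) -> str:
    """Return the GGUF tokenizer model name for a GLM-4 checkpoint directory."""
    try:
        if not (dir_model / "tokenizer.json").is_file():
            raise ValueError("tokenizer.json not found")
        # Merges can be loaded (the glm-4-9b case in the hunk above), so the
        # vocab is written as a gpt2-style BPE tokenizer with load_merges=True.
        return "gpt2"
    except Exception as e:
        # No loadable merges (the glm-4-9b-hf case), so fall back to the
        # llama tokenizer model with load_merges=False.
        logger.warning("Falling back to llama tokenizer model: %s", e)
        return "llama"

print(select_tokenizer_model(Path("models/glm-4-9b")))  # example path, assumed
```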
@@ -5045,7 +5100,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed)))
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"]))
self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")))
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5))
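The second hunk applies the same tolerance to the hyperparameters: the two GLM-4 checkpoint layouts appear to expose the layer count and vocabulary size under different keys, so direct indexing with hparams["num_hidden_layers"] would raise a KeyError for one of them. A small sketch of the .get() fallback, using trimmed and assumed config fragments (the exact keys and values are illustrative, not copied from the real config.json files):

```python
# Assumed, trimmed config fragments for the two checkpoint layouts.
glm_4_9b = {"num_layers": 40, "padded_vocab_size": 151552}      # trust_remote_code style
glm_4_9b_hf = {"num_hidden_layers": 40, "vocab_size": 151552}   # native transformers style

def block_count(hparams: dict) -> int:
    # Mirrors: hparams.get("num_layers", hparams.get("num_hidden_layers"))
    return hparams.get("num_layers", hparams.get("num_hidden_layers"))

def vocab_size(hparams: dict) -> int:
    # Mirrors: hparams.get("padded_vocab_size", hparams.get("vocab_size"))
    return hparams.get("padded_vocab_size", hparams.get("vocab_size"))

assert block_count(glm_4_9b) == block_count(glm_4_9b_hf) == 40
assert vocab_size(glm_4_9b) == vocab_size(glm_4_9b_hf) == 151552
```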