Add Qwen2 GGUF loading support #31175

Merged · 9 commits · Jun 3, 2024
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -63,6 +63,7 @@ For now the supported model architectures are the architectures that have been v

- LLaMa
- Mistral
- Qwen2

## Example usage

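For context, loading a Qwen2 GGUF checkpoint now follows the same pattern the page already shows for LLaMa and Mistral. A minimal sketch (not part of the diff), using the repository and filename exercised by the test added in this PR:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
gguf_file = "qwen1_5-0_5b-chat-q4_0.gguf"

# The GGUF file supplies both the tokenizer metadata and the weights (dequantized on load).
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

out = model.generate(**tokenizer("Hello", return_tensors="pt"), max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
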
8 changes: 5 additions & 3 deletions src/transformers/convert_slow_tokenizer.py
@@ -401,9 +401,11 @@ def converted(self) -> Tokenizer:


class Qwen2Converter(Converter):
    def converted(self) -> Tokenizer:
        vocab = self.original_tokenizer.encoder
        merges = list(self.original_tokenizer.bpe_ranks.keys())
    def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
        if not vocab:
            vocab = self.original_tokenizer.encoder
        if not merges:
            merges = list(self.original_tokenizer.bpe_ranks.keys())

        tokenizer = Tokenizer(
            BPE(
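To illustrate the signature change above (a sketch, not part of the PR): the default path is unchanged, while callers such as the GGUF converter added below can inject vocab and merges directly. The checkpoint name is only an example; any Qwen2 slow tokenizer works.

from transformers import Qwen2Tokenizer
from transformers.convert_slow_tokenizer import Qwen2Converter

slow = Qwen2Tokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")  # example checkpoint

# Unchanged default behaviour: vocab/merges are read from the wrapped slow tokenizer.
backend = Qwen2Converter(slow).converted()

# New: vocab/merges can be supplied explicitly, which is what the GGUF path relies on.
backend_explicit = Qwen2Converter(slow).converted(
    vocab=slow.encoder, merges=list(slow.bpe_ranks.keys())
)
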
64 changes: 56 additions & 8 deletions src/transformers/integrations/ggml.py
@@ -25,7 +25,7 @@
from tokenizers.models import BPE

from .. import AddedToken
from ..convert_slow_tokenizer import LlamaConverter
from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
from ..utils import logging
from ..utils.logging import tqdm

@@ -101,6 +101,21 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"qwen2": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}
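Read literally, the mapping above rewrites llama.cpp-style tensor names into Qwen2 state-dict keys. A few illustrative pairs, inferred from the table rather than taken from the loader code:

# GGUF tensor name          ->  transformers state-dict key
# "token_embd.weight"       ->  "model.embed_tokens.weight"
# "blk.0.attn_q.weight"     ->  "model.layers.0.self_attn.q_proj.weight"
# "blk.0.ffn_up.weight"     ->  "model.layers.0.mlp.up_proj.weight"
# "output_norm.weight"      ->  "model.norm.weight"
# "output.weight"           ->  "lm_head.weight"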


@@ -133,8 +148,19 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"qwen2": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.model": "model_type",
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
"ggml.unknown_token_id": "unk_token_id",
@@ -490,14 +516,15 @@ def __init__(self, dict_):
        for k, v in dict_.items():
            setattr(self, k, v)

        if not hasattr(self, "tokens") or not hasattr(self, "scores"):
            raise ValueError("tokens and scores need to be passed for a LLaMa tokenizer to be instantiated.")
        else:
        if not hasattr(self, "merges"):
            if not hasattr(self, "tokens") or not hasattr(self, "scores"):
                raise ValueError(
                    "tokens and scores need to be passed for a LLaMa tokenizer without merges to be instantiated."
                )
            tokens = self.tokens
            scores = self.scores
            vocab = {t: scores[i] for i, t in enumerate(tokens)}

        if not hasattr(self, "merges"):
            logger.warning("Merges were not in checkpoint, building merges on the fly.")
            merges = []
            for merge, piece_score in tqdm(vocab.items()):
@@ -562,16 +589,37 @@ def decoder(self, replacement, add_prefix_space):
        return decoders.Sequence(sequence)


class GGUFQwen2Converter(Qwen2Converter):
    def __init__(self, tokenizer_dict):
        self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)

    def converted(self) -> Tokenizer:
        vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
        merges = self.original_tokenizer.merges
        tokenizer = super().converted(vocab, merges)

        tokenizer.add_special_tokens(
            [
                AddedToken("<|endoftext|>", normalized=False, special=True),
                AddedToken("<|im_start|>", normalized=False, special=True),
                AddedToken("<|im_end|>", normalized=False, special=True),
            ]
        )
        return tokenizer


GGUF_TO_FAST_CONVERTERS = {
    "llama": GGUFLlamaConverter,
    "qwen2": GGUFQwen2Converter,
}


def convert_gguf_tokenizer(tokenizer_dict) -> Tokenizer:
def convert_gguf_tokenizer(architecture, tokenizer_dict) -> Tokenizer:
"""
Utilities to convert a slow tokenizer instance in a fast tokenizer instance.

Args:
architecture (`str`): The model architecture derived from gguf file.
transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
Instance of a slow tokenizer to convert in the backend tokenizer for
[`~tokenization_utils_base.PreTrainedTokenizerFast`].
@@ -580,6 +628,6 @@ def convert_gguf_tokenizer(tokenizer_dict) -> Tokenizer:
        A instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
        [`~tokenization_utils_base.PreTrainedTokenizerFast`]
    """
    tokenizer_class_name = tokenizer_dict["tokenizer_type"]
    tokenizer_class_name = architecture
    converter_class = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name]
    return converter_class(tokenizer_dict).converted()
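Taken together, GGUF tokenizer support is now keyed on the architecture string: GGUF_TO_FAST_CONVERTERS maps it to a converter class and convert_gguf_tokenizer dispatches on it. Extending this to a future architecture would look roughly like the sketch below; every name here is hypothetical, not something this PR adds.

from tokenizers import Tokenizer
from transformers.integrations.ggml import GGUF_TO_FAST_CONVERTERS, GGUFTokenizerSkeleton

class GGUFMyArchConverter:  # hypothetical converter for a not-yet-supported architecture
    def __init__(self, tokenizer_dict):
        self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)

    def converted(self) -> Tokenizer:
        ...  # build and return a tokenizers.Tokenizer from the GGUF vocab/merges fields

GGUF_TO_FAST_CONVERTERS["myarch"] = GGUFMyArchConverter  # picked up by convert_gguf_tokenizer("myarch", ...)
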
4 changes: 2 additions & 2 deletions src/transformers/models/qwen2/tokenization_qwen2_fast.py
@@ -118,8 +118,8 @@ def __init__(
        )

        super().__init__(
            vocab_file,
            merges_file,
            vocab_file=vocab_file,
            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
6 changes: 4 additions & 2 deletions src/transformers/tokenization_utils_fast.py
@@ -118,8 +118,10 @@ def __init__(self, *args, **kwargs):
            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
        elif gguf_file is not None:
            # We need to convert a slow tokenizer to build the backend
            tokenizer_dict = load_gguf_checkpoint(kwargs.get("vocab_file"))["tokenizer"]
            fast_tokenizer = convert_gguf_tokenizer(tokenizer_dict)
            gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file"))
            architecture = gguf_param["config"]["model_type"]
            tokenizer_dict = gguf_param["tokenizer"]
            fast_tokenizer = convert_gguf_tokenizer(architecture, tokenizer_dict)
        elif self.slow_tokenizer_class is not None:
            # We need to create and convert a slow tokenizer to build the backend
            slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
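Put differently, the fast-tokenizer constructor now parses the GGUF metadata once, takes the architecture from the resulting config, and hands both to the converter. A simplified sketch of that flow, assuming a locally downloaded GGUF file and the module paths used by the surrounding code:

from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint
from transformers.integrations.ggml import convert_gguf_tokenizer

gguf_param = load_gguf_checkpoint("qwen1_5-0_5b-chat-q4_0.gguf")  # local path to the GGUF file
architecture = gguf_param["config"]["model_type"]                 # e.g. "qwen2"
backend_tokenizer = convert_gguf_tokenizer(architecture, gguf_param["tokenizer"])
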
14 changes: 14 additions & 0 deletions tests/quantization/ggml/test_ggml.py
@@ -31,6 +31,7 @@ class GgufIntegrationTests(unittest.TestCase):
    original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
    mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
    qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"

    q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
    q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
@@ -41,6 +42,7 @@ class GgufIntegrationTests(unittest.TestCase):
    q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"

    q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
    q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"

    example_text = "Hello"

@@ -157,6 +159,18 @@ def test_mistral_q4_0(self):
        EXPECTED_TEXT = "Hello,\n\nI'm trying to create a"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_qwen2_q4_0(self):
        tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
        model = AutoModelForCausalLM.from_pretrained(
            self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
        )

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_tokenization_xnli(self):
        import tqdm
        from datasets import load_dataset