diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index 890ca042488154..9f45d2ca4cf5d4 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -82,6 +82,7 @@ For now the supported model architectures are the architectures that have been v
 - Phi3
 - Bloom
 - Falcon
+- StableLM
 
 ## Example usage
 
diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index ca39b5ef5f917a..997fa6a4ec2435 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -143,6 +143,21 @@
         ".output.": ".lm_head.",
         "output_norm": "ln_f",
     },
+    "stablelm": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.layers",
+        "ffn_up": "mlp.up_proj",
+        "ffn_down": "mlp.down_proj",
+        "ffn_gate": "mlp.gate_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_q": "self_attn.q_proj",
+        "attn_v": "self_attn.v_proj",
+        "attn_k": "self_attn.k_proj",
+        "attn_output": "self_attn.o_proj",
+        "output.weight": "lm_head.weight",
+        "output_norm": "model.norm",
+    },
 }
 
 
@@ -238,6 +253,17 @@
         "vocab_size": "vocab_size",
         "attention.layer_norm_epsilon": "layer_norm_epsilon",
     },
+    "stablelm": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_epsilon": "layer_norm_eps",
+        "vocab_size": "vocab_size",
+    },
 }
 
 GGUF_TOKENIZER_MAPPING = {
@@ -547,7 +573,7 @@ def converted(self) -> Tokenizer:
         return tokenizer
 
 
-class GGUFBloomConverter(GPT2Converter):
+class GGUFGPTConverter(GPT2Converter):
     def __init__(self, tokenizer_dict):
         self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
         self.additional_kwargs = {}
@@ -564,8 +590,9 @@ def converted(self) -> Tokenizer:
     "qwen2": GGUFQwen2Converter,
     "qwen2_moe": GGUFQwen2Converter,
     "phi3": GGUFPhi3Converter,
-    "bloom": GGUFBloomConverter,
-    "falcon": GGUFBloomConverter,
+    "bloom": GGUFGPTConverter,
+    "falcon": GGUFGPTConverter,
+    "stablelm": GGUFGPTConverter,
 }
 
 
diff --git a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
index c79e6d9ada15d3..7fafa440d05113 100644
--- a/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
+++ b/src/transformers/models/gpt_neox/tokenization_gpt_neox_fast.py
@@ -105,8 +105,8 @@ def __init__(
         **kwargs,
     ):
         super().__init__(
-            vocab_file,
-            merges_file,
+            vocab_file=vocab_file,
+            merges_file=merges_file,
             tokenizer_file=tokenizer_file,
             unk_token=unk_token,
             bos_token=bos_token,
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index ddc6288f36dd31..bf207111cf1932 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -47,6 +47,9 @@ class GgufIntegrationTests(unittest.TestCase):
     falcon7b_model_id = "xaviviro/falcon-7b-quantized-gguf"
     falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf"
     original_flacon7b_model_id = "tiiuae/falcon-7b"
+    stablelm_model_id = "afrideva/stablelm-3b-4e1t-GGUF"
+    stablelm2_model_id = "afrideva/stablelm-2-1_6b-GGUF"
+    original_stablelm2_model_id = "stabilityai/stablelm-2-1_6b"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -58,6 +61,7 @@ class GgufIntegrationTests(unittest.TestCase):
     q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
     q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
     q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
+    q4_k_m_stablelm_model_id = "stablelm-3b-4e1t.q4_k_m.gguf"
     # imatrix
     iq1_m_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_M.gguf"
     iq1_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_S.gguf"
@@ -75,6 +79,7 @@ class GgufIntegrationTests(unittest.TestCase):
     q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
     q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
     fp16_bloom_model_id = "bloom-560m.fp16.gguf"
+    fp16_stablelm2_model_id = "stablelm-2-1_6b.fp16.gguf"
     q8_bloom_model_id = "bloom-560m.q8_0.gguf"
     f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
     q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf"
@@ -503,6 +508,75 @@ def test_falcon7b_weights_conversion_fp16(self):
                 self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
                 torch.testing.assert_close(original_params, quantized_state_dict[layer_name])
 
+    def test_stablelm_q4_k_m(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            self.stablelm_model_id,
+            gguf_file=self.q4_k_m_stablelm_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.stablelm_model_id, gguf_file=self.q4_k_m_stablelm_model_id)
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello-\nI am trying to create a new user"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_stablelm_fp16(self):
+        original_model = AutoModelForCausalLM.from_pretrained(
+            self.original_stablelm2_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_model = AutoModelForCausalLM.from_pretrained(
+            self.stablelm2_model_id,
+            gguf_file=self.fp16_stablelm2_model_id,
+            torch_dtype=torch.float16,
+            # for a precise comparison it is required to use the original model config,
+            # as the quantized one differs in the parameters use_parallel_residual and use_qkv_bias,
+            # which strongly influence the output results
+            config=original_model.config,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.stablelm2_model_id, gguf_file=self.fp16_stablelm2_model_id)
+        text = tokenizer(self.example_text, return_tensors="pt")
+        original_out = original_model.generate(**text, max_new_tokens=10)
+        converted_out = converted_model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I am a 20 year old male"
+        self.assertEqual(tokenizer.decode(converted_out[0], skip_special_tokens=True), EXPECTED_TEXT)
+        self.assertEqual(
+            tokenizer.decode(converted_out[0], skip_special_tokens=True),
+            tokenizer.decode(original_out[0], skip_special_tokens=True),
+        )
+
+    def test_stablelm_weights_conversion_fp16(self):
+        original_model = AutoModelForCausalLM.from_pretrained(
+            self.original_stablelm2_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        converted_model = AutoModelForCausalLM.from_pretrained(
+            self.stablelm2_model_id,
+            gguf_file=self.fp16_stablelm2_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+            # for a precise comparison it is required to use the original model config,
+            # as the quantized one differs in the parameters use_parallel_residual and use_qkv_bias,
+            # which strongly influence the output results
+            config=original_model.config,
+        )
+
+        converted_state_dict = converted_model.state_dict()
+        original_state_dict = original_model.state_dict()
+
+        for layer_name, original_params in original_state_dict.items():
+            if layer_name in converted_state_dict:
+                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
+                torch.testing.assert_close(original_params, converted_state_dict[layer_name])
+
     def test_tokenization_xnli(self):
         import tqdm
         from datasets import load_dataset
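
For reference, a minimal usage sketch of the StableLM GGUF path enabled by this patch, assuming the repo and file names already used in the new tests (the prompt and generation settings are illustrative only):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Repo/file ids taken from the new tests; any StableLM GGUF checkpoint should work the same way.
model_id = "afrideva/stablelm-3b-4e1t-GGUF"
gguf_file = "stablelm-3b-4e1t.q4_k_m.gguf"

# Both the tokenizer and the dequantized weights are reconstructed from the GGUF file.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    gguf_file=gguf_file,
    device_map="auto",
    torch_dtype=torch.float16,
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```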