Add gguf support for StableLM (#33793)
* add stablelm gguf architecture support

* add additional quantization tests

* resolve merge conflict, add weight conversion tests for fp16
VladOS95-cyber authored Oct 9, 2024
1 parent e783f12 commit faa0f63
Showing 4 changed files with 107 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -82,6 +82,7 @@ For now the supported model architectures are the architectures that have been v
- Phi3
- Bloom
- Falcon
- StableLM

## Example usage

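A minimal usage sketch for the newly listed StableLM support, assuming the GGUF repo and file name that the tests added later in this commit rely on (any other supported GGUF checkpoint follows the same pattern):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Repo and GGUF file name mirror the StableLM test constants added in this commit.
model_id = "afrideva/stablelm-3b-4e1t-GGUF"
gguf_file = "stablelm-3b-4e1t.q4_k_m.gguf"

tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)
```

Passing `gguf_file` tells `from_pretrained` to dequantize the GGUF weights into a regular PyTorch model, which is exactly what the integration tests below exercise.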
33 changes: 30 additions & 3 deletions src/transformers/integrations/ggml.py
@@ -148,6 +148,21 @@
        ".output.": ".lm_head.",
        "output_norm": "ln_f",
    },
    "stablelm": {
        "token_embd": "model.embed_tokens",
        "blk": "model.layers",
        "ffn_up": "mlp.up_proj",
        "ffn_down": "mlp.down_proj",
        "ffn_gate": "mlp.gate_proj",
        "ffn_norm": "post_attention_layernorm",
        "attn_norm": "input_layernorm",
        "attn_q": "self_attn.q_proj",
        "attn_v": "self_attn.v_proj",
        "attn_k": "self_attn.k_proj",
        "attn_output": "self_attn.o_proj",
        "output.weight": "lm_head.weight",
        "output_norm": "model.norm",
    },
}
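The new `stablelm` block above maps GGUF tensor-name fragments to their Transformers state-dict counterparts. As an illustrative sketch only (not the library's actual conversion code), applying a mapping of this shape amounts to substring replacement on each GGUF tensor name:

```python
# Hypothetical, self-contained illustration of how a fragment mapping of this shape
# can turn GGUF tensor names into Transformers state-dict keys.
STABLELM_FRAGMENTS = {
    "token_embd": "model.embed_tokens",
    "blk": "model.layers",
    "attn_q": "self_attn.q_proj",
    "output_norm": "model.norm",
}


def rename_tensor(gguf_name: str, fragments: dict) -> str:
    for gguf_fragment, hf_fragment in fragments.items():
        if gguf_fragment in gguf_name:
            gguf_name = gguf_name.replace(gguf_fragment, hf_fragment)
    return gguf_name


# "blk.0.attn_q.weight" -> "model.layers.0.self_attn.q_proj.weight"
print(rename_tensor("blk.0.attn_q.weight", STABLELM_FRAGMENTS))
```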


@@ -245,6 +260,17 @@
        "vocab_size": "vocab_size",
        "attention.layer_norm_epsilon": "layer_norm_epsilon",
    },
    "stablelm": {
        "context_length": "max_position_embeddings",
        "block_count": "num_hidden_layers",
        "feed_forward_length": "intermediate_size",
        "embedding_length": "hidden_size",
        "rope.dimension_count": None,
        "attention.head_count": "num_attention_heads",
        "attention.head_count_kv": "num_key_value_heads",
        "attention.layer_norm_epsilon": "layer_norm_eps",
        "vocab_size": "vocab_size",
    },
}
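The matching `stablelm` config entry maps GGUF metadata keys to StableLM config attributes; a `None` target (as for `rope.dimension_count`) presumably means the field is read from the file but not forwarded to the config. A hedged sketch of how parsed metadata could be translated into config kwargs with such a mapping:

```python
# Hypothetical sketch: translate GGUF metadata (already parsed into a flat dict)
# into config keyword arguments, skipping fields whose target is None.
METADATA_TO_CONFIG = {
    "context_length": "max_position_embeddings",
    "block_count": "num_hidden_layers",
    "rope.dimension_count": None,  # read from the file, but not mapped here
}


def to_config_kwargs(metadata: dict, mapping: dict) -> dict:
    return {
        config_key: metadata[gguf_key]
        for gguf_key, config_key in mapping.items()
        if config_key is not None and gguf_key in metadata
    }


# {'max_position_embeddings': 4096, 'num_hidden_layers': 32}
print(to_config_kwargs({"context_length": 4096, "block_count": 32}, METADATA_TO_CONFIG))
```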

GGUF_TOKENIZER_MAPPING = {
@@ -554,7 +580,7 @@ def converted(self) -> Tokenizer:
        return tokenizer


-class GGUFBloomConverter(GPT2Converter):
+class GGUFGPTConverter(GPT2Converter):
    def __init__(self, tokenizer_dict):
        self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
        self.additional_kwargs = {}
@@ -571,8 +597,9 @@ def converted(self) -> Tokenizer:
    "qwen2": GGUFQwen2Converter,
    "qwen2_moe": GGUFQwen2Converter,
    "phi3": GGUFPhi3Converter,
-    "bloom": GGUFBloomConverter,
-    "falcon": GGUFBloomConverter,
+    "bloom": GGUFGPTConverter,
+    "falcon": GGUFGPTConverter,
+    "stablelm": GGUFGPTConverter,
}
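The registry change reuses the GPT-2-style tokenizer converter (renamed from `GGUFBloomConverter` to `GGUFGPTConverter`) for StableLM as well. Conceptually the registry is a plain dict dispatch on the GGUF architecture string; a simplified, hypothetical sketch of that lookup (not the exact transformers call site):

```python
# Illustrative only: pick a tokenizer converter class by GGUF architecture string.
class GGUFGPTConverter:
    """Stand-in for the converter class renamed in this commit."""

    def __init__(self, tokenizer_dict):
        self.tokenizer_dict = tokenizer_dict


CONVERTERS = {
    "bloom": GGUFGPTConverter,
    "falcon": GGUFGPTConverter,
    "stablelm": GGUFGPTConverter,
}


def get_converter(architecture: str, tokenizer_dict: dict):
    if architecture not in CONVERTERS:
        raise ValueError(f"GGUF architecture {architecture!r} has no tokenizer converter")
    return CONVERTERS[architecture](tokenizer_dict)


converter = get_converter("stablelm", {"tokens": [], "merges": []})
```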


2 changes: 2 additions & 2 deletions
@@ -105,8 +105,8 @@ def __init__(
        **kwargs,
    ):
        super().__init__(
-            vocab_file,
-            merges_file,
+            vocab_file=vocab_file,
+            merges_file=merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
74 changes: 74 additions & 0 deletions tests/quantization/ggml/test_ggml.py
@@ -48,6 +48,9 @@ class GgufIntegrationTests(unittest.TestCase):
    falcon7b_model_id = "xaviviro/falcon-7b-quantized-gguf"
    falcon40b_model_id = "maddes8cht/tiiuae-falcon-40b-gguf"
    original_flacon7b_model_id = "tiiuae/falcon-7b"
    stablelm_model_id = "afrideva/stablelm-3b-4e1t-GGUF"
    stablelm2_model_id = "afrideva/stablelm-2-1_6b-GGUF"
    original_stablelm2_model_id = "stabilityai/stablelm-2-1_6b"

    # standard quants
    q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -59,6 +62,7 @@ class GgufIntegrationTests(unittest.TestCase):
    q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
    q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
    q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
    q4_k_m_stablelm_model_id = "stablelm-3b-4e1t.q4_k_m.gguf"
    # imatrix
    iq1_m_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_M.gguf"
    iq1_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_S.gguf"
@@ -76,6 +80,7 @@ class GgufIntegrationTests(unittest.TestCase):
    q8_qwen2moe_model_id = "Qwen1.5-MoE-A2.7B_Q8_0.gguf"
    q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
    fp16_bloom_model_id = "bloom-560m.fp16.gguf"
    fp16_stablelm2_model_id = "stablelm-2-1_6b.fp16.gguf"
    q8_bloom_model_id = "bloom-560m.q8_0.gguf"
    f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
    q2_k_falcon7b_model_id = "falcon-7b-q2_k.gguf"
@@ -523,6 +528,75 @@ def test_falcon7b_weights_conversion_fp16(self):
                self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
                torch.testing.assert_close(original_params, quantized_state_dict[layer_name])

    def test_stablelm_q4_k_m(self):
        model = AutoModelForCausalLM.from_pretrained(
            self.stablelm_model_id,
            gguf_file=self.q4_k_m_stablelm_model_id,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        tokenizer = AutoTokenizer.from_pretrained(self.stablelm_model_id, gguf_file=self.q4_k_m_stablelm_model_id)
        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello-\nI am trying to create a new user"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_stablelm_fp16(self):
        original_model = AutoModelForCausalLM.from_pretrained(
            self.original_stablelm2_model_id,
            torch_dtype=torch.float16,
        )

        converted_model = AutoModelForCausalLM.from_pretrained(
            self.stablelm2_model_id,
            gguf_file=self.fp16_stablelm2_model_id,
            torch_dtype=torch.float16,
            # for a precise comparison we need the original model config, since the
            # config derived from the GGUF file differs in the use_parallel_residual
            # and use_qkv_bias parameters, which strongly affects the generated output
            config=original_model.config,
        )

        tokenizer = AutoTokenizer.from_pretrained(self.stablelm2_model_id, gguf_file=self.fp16_stablelm2_model_id)
        text = tokenizer(self.example_text, return_tensors="pt")
        original_out = original_model.generate(**text, max_new_tokens=10)
        converted_out = converted_model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, I am a 20 year old male"
        self.assertEqual(tokenizer.decode(converted_out[0], skip_special_tokens=True), EXPECTED_TEXT)
        self.assertEqual(
            tokenizer.decode(converted_out[0], skip_special_tokens=True),
            tokenizer.decode(original_out[0], skip_special_tokens=True),
        )

    def test_stablelm_weights_conversion_fp16(self):
        original_model = AutoModelForCausalLM.from_pretrained(
            self.original_stablelm2_model_id,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        converted_model = AutoModelForCausalLM.from_pretrained(
            self.stablelm2_model_id,
            gguf_file=self.fp16_stablelm2_model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            # for a precise comparison we need the original model config, since the
            # config derived from the GGUF file differs in the use_parallel_residual
            # and use_qkv_bias parameters, which strongly affects the generated output
            config=original_model.config,
        )

        converted_state_dict = converted_model.state_dict()
        original_state_dict = original_model.state_dict()

        for layer_name, original_params in original_state_dict.items():
            if layer_name in converted_state_dict:
                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
                torch.testing.assert_close(original_params, converted_state_dict[layer_name])

    def test_tokenization_xnli(self):
        import tqdm
        from datasets import load_dataset