
Add GGUF for Mamba (#34200)
* add mamba architecture for gguf

* add logic for weights conversion, some fixes and refactoring

* add lm_head layers, unit test refactoring

* more fixes for tests

* remove lm_head creation

* remove unused comments
VladOS95-cyber authored Oct 30, 2024
1 parent eab6c49 commit 5251fe6
Showing 4 changed files with 93 additions and 2 deletions.
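With this change, a Mamba GGUF checkpoint loads through the existing `gguf_file` code path. A minimal usage sketch (the repository and file names are the ones exercised by the new tests below; any other Mamba GGUF export should work the same way):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Repo and filename taken from the tests added in this commit.
model_id = "jpodivin/mamba-2.8b-hf-GGUF"
gguf_file = "ggml-model-Q6_K.gguf"

# The GGUF tensors are de-quantized and remapped to the transformers Mamba
# layout on load (see ggml.py and modeling_gguf_pytorch_utils.py below).
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

input_ids = tokenizer("Hello", return_tensors="pt")["input_ids"]
print(tokenizer.decode(model.generate(input_ids, max_new_tokens=10)[0]))
```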
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -86,6 +86,7 @@ For now the supported model architectures are the architectures that have been v
- GPT2
- Starcoder2
- T5
- Mamba

## Example usage

25 changes: 25 additions & 0 deletions src/transformers/integrations/ggml.py
@@ -235,6 +235,19 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"mamba": {
"token_embd": "backbone.embeddings",
"blk": "backbone.layers",
"ssm_a": "mixer.A_log",
"ssm_conv1d": "mixer.conv1d",
"ssm_in": "mixer.in_proj",
"ssm_out": "mixer.out_proj",
"ssm_x": "mixer.x_proj",
"ssm_dt": "mixer.dt_proj",
"attn_norm": "norm",
"output_norm": "backbone.norm_f",
"output.weight": "lm_head.weight",
},
}
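The per-architecture tables above are applied fragment by fragment to each GGUF tensor name. A toy illustration of how the new `"mamba"` entries compose (not the library's actual renaming code, just the idea behind the mapping):

```python
# Hypothetical demonstration: a subset of the "mamba" mapping added above.
mamba_map = {
    "token_embd": "backbone.embeddings",
    "blk": "backbone.layers",
    "ssm_in": "mixer.in_proj",
    "output_norm": "backbone.norm_f",
}


def rename(gguf_name: str) -> str:
    # Replace every known GGUF fragment with its transformers counterpart.
    for gguf_key, hf_key in mamba_map.items():
        gguf_name = gguf_name.replace(gguf_key, hf_key)
    return gguf_name


assert rename("blk.0.ssm_in.weight") == "backbone.layers.0.mixer.in_proj.weight"
assert rename("token_embd.weight") == "backbone.embeddings.weight"
assert rename("output_norm.weight") == "backbone.norm_f.weight"
```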


@@ -373,6 +386,17 @@
        "attention.head_count_kv": "num_key_value_heads",
        "attention.layer_norm_epsilon": "norm_epsilon",
    },
    "mamba": {
        "vocab_size": "vocab_size",
        "context_length": "max_position_embeddings",
        "embedding_length": "hidden_size",
        "attention.layer_norm_rms_epsilon": "layer_norm_epsilon",
        "block_count": "num_hidden_layers",
        "ssm.conv_kernel": "conv_kernel",
        "ssm.state_size": "state_size",
        "ssm.time_step_rank": "time_step_rank",
        "ssm.inner_size": "intermediate_size",
    },
}
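These entries translate keys from the GGUF metadata (stored under the architecture prefix) into `MambaConfig` arguments. A hedged sketch of that translation, with illustrative values only (the real numbers come from the exported checkpoint):

```python
# Hypothetical GGUF metadata for a mamba checkpoint; values are illustrative.
gguf_metadata = {
    "mamba.context_length": 1048576,
    "mamba.embedding_length": 2560,
    "mamba.block_count": 64,
    "mamba.ssm.conv_kernel": 4,
    "mamba.ssm.state_size": 16,
    "mamba.ssm.time_step_rank": 160,
    "mamba.ssm.inner_size": 5120,
    "mamba.attention.layer_norm_rms_epsilon": 1e-5,
}

# Same renaming table as the "mamba" block added above.
gguf_to_config = {
    "context_length": "max_position_embeddings",
    "embedding_length": "hidden_size",
    "block_count": "num_hidden_layers",
    "ssm.conv_kernel": "conv_kernel",
    "ssm.state_size": "state_size",
    "ssm.time_step_rank": "time_step_rank",
    "ssm.inner_size": "intermediate_size",
    "attention.layer_norm_rms_epsilon": "layer_norm_epsilon",
}

config_kwargs = {
    gguf_to_config[key.removeprefix("mamba.")]: value for key, value in gguf_metadata.items()
}
# config_kwargs -> {"max_position_embeddings": 1048576, "hidden_size": 2560, ...}
```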

GGUF_TOKENIZER_MAPPING = {
@@ -768,6 +792,7 @@ def converted(self) -> Tokenizer:
    "gpt2": GGUFGPTConverter,
    "starcoder2": GGUFGPTConverter,
    "t5": GGUFT5Converter,
    "mamba": GGUFGPTConverter,
}


13 changes: 13 additions & 0 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -220,6 +220,19 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
                    name = "lm_head.weight"
                    parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
                    continue
            if architecture == "mamba":
                if "ssm_d" in name and "bias" not in name and "weight" not in name:
                    # "ssm_d" is also a substring of "ssm_dt.weight"/"ssm_dt.bias",
                    # so only rename tensors whose name is exactly ssm_d
                    name = name.replace("ssm_d", "mixer.D")
                if "ssm_conv1d.weight" in name:
                    # the transformers conv1d weight must have shape (5120, 1, 4),
                    # while the quantized GGUF tensor is stored as (5120, 4)
                    weights = np.expand_dims(weights, axis=1)
                if "ssm_a" in name:
                    # the llama.cpp converter stores A as -exp(A_log); reverse it to recover A_log
                    # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
                    weights = np.log(-weights)

            for tensor_name in tensor_key_mapping:
                if tensor_name.format(bid=bid) in name:
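The `ssm_a` branch above reverses the `-exp(A_log)` transform that the llama.cpp exporter applies, and `ssm_conv1d` gets back the singleton channel axis that the transformers conv1d layout expects. A small numpy sketch of both transforms (illustrative only, with a made-up layer size):

```python
import numpy as np

# Export direction (what llama.cpp's convert_hf_to_gguf.py does):
# the HF checkpoint stores A_log, the GGUF file stores A = -exp(A_log).
a_log_hf = np.random.randn(8, 16).astype(np.float32)  # made-up (intermediate_size, ssm_state_size)
a_gguf = -np.exp(a_log_hf)

# Load direction (the code above): recover A_log from the GGUF tensor.
a_log_recovered = np.log(-a_gguf)
# Floating-point round-tripping through exp/log is why the fp16 conversion
# test below compares A_log with a relaxed tolerance rather than exact equality.
assert np.allclose(a_log_recovered, a_log_hf, atol=1e-3)

# conv1d weights: GGUF stores them 2-D, transformers expects
# (intermediate_size, 1, conv_kernel), so restore the singleton axis.
conv_gguf = np.random.randn(8, 4).astype(np.float32)  # made-up (intermediate_size, conv_kernel)
conv_hf = np.expand_dims(conv_gguf, axis=1)
assert conv_hf.shape == (8, 1, 4)
```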
56 changes: 54 additions & 2 deletions tests/quantization/ggml/test_ggml.py
@@ -59,6 +59,8 @@ class GgufIntegrationTests(unittest.TestCase):
    starcoder2_model_id = "QuantFactory/starcoder2-3b-GGUF"
    starcoder2_fp16_model_id = "brittlewis12/starcoder2-3b-GGUF"
    starcoder2_original_model_id = "bigcode/starcoder2-3b"
    mamba_original_model_id = "state-spaces/mamba-2.8b-hf"
    mamba_model_id = "jpodivin/mamba-2.8b-hf-GGUF"

    # standard quants
    q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -102,6 +104,8 @@ class GgufIntegrationTests(unittest.TestCase):
    q6_k_gpt2_xl_model_id = "gpt2-xl.Q6_K.gguf"
    q6_k_starcoder2_model_id = "starcoder2-3b.Q6_K.gguf"
    fp16_starcoder2_gguf_model_id = "starcoder2-3b.fp16.gguf"
    q6_k_mamba_model_id = "ggml-model-Q6_K.gguf"
    fp16_mamba_model_id = "ggml-model-f16.gguf"

    example_text = "Hello"

@@ -573,6 +577,8 @@ def test_gpt2_weights_conversion_fp16(self):
            if layer_name in quantized_state_dict:
                self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
                torch.testing.assert_close(original_params, quantized_state_dict[layer_name])
            else:
                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")

    def test_gpt2_xl_Q6_K(self):
        tokenizer = AutoTokenizer.from_pretrained(self.gpt2_xl_model_id, gguf_file=self.q6_k_gpt2_xl_model_id)
@@ -639,6 +645,8 @@ def test_falcon7b_weights_conversion_fp16(self):
            if layer_name in quantized_state_dict:
                self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
                torch.testing.assert_close(original_params, quantized_state_dict[layer_name])
            else:
                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")

    def test_stablelm_q4_k_m(self):
        model = AutoModelForCausalLM.from_pretrained(
@@ -708,6 +716,8 @@ def test_stablelm_weights_conversion_fp16(self):
            if layer_name in converted_state_dict:
                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
                torch.testing.assert_close(original_params, converted_state_dict[layer_name])
            else:
                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")

    def test_starcoder2_weights_conversion_fp16(self):
        original_model = AutoModelForCausalLM.from_pretrained(
@@ -727,10 +737,11 @@ def test_starcoder2_weights_conversion_fp16(self):
        original_state_dict = original_model.state_dict()

        for layer_name, original_params in original_state_dict.items():
            if layer_name in converted_state_dict and layer_name != "lm_head.weight":
                # quantized models do not contain "lm_head.weight" layer
            if layer_name in converted_state_dict:
                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
                torch.testing.assert_close(original_params, converted_state_dict[layer_name])
            else:
                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")

    def test_starcoder2_q6_k(self):
        example_function_text = "def print_hello_world():"
@@ -748,6 +759,47 @@ def test_starcoder2_q6_k(self):
        EXPECTED_TEXT = 'def print_hello_world():\n print("Hello World")\n\ndef print'
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_mamba_weights_conversion_fp16(self):
        original_model = AutoModelForCausalLM.from_pretrained(
            self.mamba_original_model_id,
            torch_dtype=torch.float16,
        )

        converted_model = AutoModelForCausalLM.from_pretrained(
            self.mamba_model_id,
            gguf_file=self.fp16_mamba_model_id,
            torch_dtype=torch.float16,
        )

        converted_state_dict = converted_model.state_dict()
        original_state_dict = original_model.state_dict()

        for layer_name, original_params in original_state_dict.items():
            if layer_name in converted_state_dict:
                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
                if "mixer.A_log" in layer_name:
                    # A_log is recovered with np.log(-weights) after the exporter's -exp(A_log)
                    # transform, which leaves small numerical differences, so compare with a
                    # relaxed tolerance
                    torch.testing.assert_close(original_params, converted_state_dict[layer_name], atol=1e-3, rtol=1e-3)
                else:
                    torch.testing.assert_close(original_params, converted_state_dict[layer_name])
            else:
                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")

    def test_mamba_q6_k(self):
        model = AutoModelForCausalLM.from_pretrained(
            self.mamba_model_id,
            gguf_file=self.q6_k_mamba_model_id,
            torch_dtype=torch.float16,
        )

        tokenizer = AutoTokenizer.from_pretrained(self.mamba_model_id, gguf_file=self.q6_k_mamba_model_id)
        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
        out = model.generate(text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello,I answerthe question.\n\nA"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_tokenization_xnli(self):
        import tqdm
        from datasets import load_dataset
