From 784af4cba9af8d1a0e2692cf43cf90df97e86751 Mon Sep 17 00:00:00 2001
From: Vladislav Bronzov <58587565+VladOS95-cyber@users.noreply.github.com>
Date: Wed, 30 Oct 2024 16:52:17 +0100
Subject: [PATCH] Add GGUF for Mamba (#34200)

* add mamba architecture for gguf

* add logic for weights conversion, some fixes and refactoring

* add lm_head layers, unit test refactoring

* more fixes for tests

* remove lm_head creation

* remove unused comments
---
 docs/source/en/gguf.md                 |  1 +
 src/transformers/integrations/ggml.py  | 25 +++++++++
 .../modeling_gguf_pytorch_utils.py     | 13 +++++
 tests/quantization/ggml/test_ggml.py   | 56 ++++++++++++++++++-
 4 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index 20531b990bc341..2da721b28986af 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -86,6 +86,7 @@ For now the supported model architectures are the architectures that have been v
 - GPT2
 - Starcoder2
 - T5
+- Mamba
 
 ## Example usage

diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index 4a2740fcb30e1c..f4545f2698c017 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -235,6 +235,19 @@
         "output.weight": "lm_head.weight",
         "output_norm": "model.norm",
     },
+    "mamba": {
+        "token_embd": "backbone.embeddings",
+        "blk": "backbone.layers",
+        "ssm_a": "mixer.A_log",
+        "ssm_conv1d": "mixer.conv1d",
+        "ssm_in": "mixer.in_proj",
+        "ssm_out": "mixer.out_proj",
+        "ssm_x": "mixer.x_proj",
+        "ssm_dt": "mixer.dt_proj",
+        "attn_norm": "norm",
+        "output_norm": "backbone.norm_f",
+        "output.weight": "lm_head.weight",
+    },
 }
 
@@ -373,6 +386,17 @@
         "attention.head_count_kv": "num_key_value_heads",
         "attention.layer_norm_epsilon": "norm_epsilon",
     },
+    "mamba": {
+        "vocab_size": "vocab_size",
+        "context_length": "max_position_embeddings",
+        "embedding_length": "hidden_size",
+        "attention.layer_norm_rms_epsilon": "layer_norm_epsilon",
+        "block_count": "num_hidden_layers",
+        "ssm.conv_kernel": "conv_kernel",
+        "ssm.state_size": "state_size",
+        "ssm.time_step_rank": "time_step_rank",
+        "ssm.inner_size": "intermediate_size",
+    },
 }
 
 GGUF_TOKENIZER_MAPPING = {
@@ -768,6 +792,7 @@ def converted(self) -> Tokenizer:
     "gpt2": GGUFGPTConverter,
     "starcoder2": GGUFGPTConverter,
     "t5": GGUFT5Converter,
+    "mamba": GGUFGPTConverter,
 }
 
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 171b2f4d15b122..c784ca0eb4ca2c 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -220,6 +220,19 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
                 name = "lm_head.weight"
                 parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
                 continue
+        if architecture == "mamba":
+            if "ssm_d" in name and "bias" not in name and "weight" not in name:
+                # "ssm_d" conflicts with "ssm_dt" in substring name checking,
+                # so we have to explicitly verify that the name is exactly ssm_d
+                name = name.replace("ssm_d", "mixer.D")
+            if "ssm_conv1d.weight" in name:
+                # for compatibility the ssm_conv1d tensor must have shape (5120, 1, 4),
+                # while the quantized one is (5120, 4)
+                weights = np.expand_dims(weights, axis=1)
+            if "ssm_a" in name:
+                # Original exponential implementation
+                # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
+                weights = np.log(-weights)
 
         for tensor_name in tensor_key_mapping:
             if tensor_name.format(bid=bid) in name:
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index ddc791e96a6489..da1af9bff8df90 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -59,6 +59,8 @@ class GgufIntegrationTests(unittest.TestCase):
     starcoder2_model_id = "QuantFactory/starcoder2-3b-GGUF"
     starcoder2_fp16_model_id = "brittlewis12/starcoder2-3b-GGUF"
     starcoder2_original_model_id = "bigcode/starcoder2-3b"
+    mamba_original_model_id = "state-spaces/mamba-2.8b-hf"
+    mamba_model_id = "jpodivin/mamba-2.8b-hf-GGUF"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -102,6 +104,8 @@ class GgufIntegrationTests(unittest.TestCase):
     q6_k_gpt2_xl_model_id = "gpt2-xl.Q6_K.gguf"
     q6_k_starcoder2_model_id = "starcoder2-3b.Q6_K.gguf"
     fp16_starcoder2_gguf_model_id = "starcoder2-3b.fp16.gguf"
+    q6_k_mamba_model_id = "ggml-model-Q6_K.gguf"
+    fp16_mamba_model_id = "ggml-model-f16.gguf"
 
     example_text = "Hello"
 
@@ -573,6 +577,8 @@ def test_gpt2_weights_conversion_fp16(self):
             if layer_name in quantized_state_dict:
                 self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
                 torch.testing.assert_close(original_params, quantized_state_dict[layer_name])
+            else:
+                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")
 
     def test_gpt2_xl_Q6_K(self):
         tokenizer = AutoTokenizer.from_pretrained(self.gpt2_xl_model_id, gguf_file=self.q6_k_gpt2_xl_model_id)
@@ -639,6 +645,8 @@ def test_falcon7b_weights_conversion_fp16(self):
             if layer_name in quantized_state_dict:
                 self.assertTrue(original_params.shape == quantized_state_dict[layer_name].shape)
                 torch.testing.assert_close(original_params, quantized_state_dict[layer_name])
+            else:
+                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")
 
     def test_stablelm_q4_k_m(self):
         model = AutoModelForCausalLM.from_pretrained(
@@ -708,6 +716,8 @@ def test_stablelm_weights_conversion_fp16(self):
             if layer_name in converted_state_dict:
                 self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
                 torch.testing.assert_close(original_params, converted_state_dict[layer_name])
+            else:
+                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")
 
     def test_starcoder2_weights_conversion_fp16(self):
         original_model = AutoModelForCausalLM.from_pretrained(
@@ -727,10 +737,11 @@ def test_starcoder2_weights_conversion_fp16(self):
         original_state_dict = original_model.state_dict()
 
         for layer_name, original_params in original_state_dict.items():
-            if layer_name in converted_state_dict and layer_name != "lm_head.weight":
-                # quantized models do not contain "lm_head.weight" layer
+            if layer_name in converted_state_dict:
                 self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
                 torch.testing.assert_close(original_params, converted_state_dict[layer_name])
+            else:
+                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")
 
     def test_starcoder2_q6_k(self):
         example_function_text = "def print_hello_world():"
@@ -748,6 +759,47 @@ def test_starcoder2_q6_k(self):
         EXPECTED_TEXT = 'def print_hello_world():\n print("Hello World")\n\ndef print'
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_mamba_weights_conversion_fp16(self):
+        original_model = AutoModelForCausalLM.from_pretrained(
+            self.mamba_original_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_model = AutoModelForCausalLM.from_pretrained(
+            self.mamba_model_id,
+            gguf_file=self.fp16_mamba_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        converted_state_dict = converted_model.state_dict()
+        original_state_dict = original_model.state_dict()
+
+        for layer_name, original_params in original_state_dict.items():
+            if layer_name in converted_state_dict:
+                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
+                if "mixer.A_log" in layer_name:
+                    # tolerance is increased because reversing the exponential with np.log(-weights)
+                    # leaves the numbers slightly different
+                    torch.testing.assert_close(original_params, converted_state_dict[layer_name], atol=1e-3, rtol=1e-3)
+                else:
+                    torch.testing.assert_close(original_params, converted_state_dict[layer_name])
+            else:
+                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")
+
+    def test_mamba_q6_k(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            self.mamba_model_id,
+            gguf_file=self.q6_k_mamba_model_id,
+            torch_dtype=torch.float16,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(self.mamba_model_id, gguf_file=self.q6_k_mamba_model_id)
+        text = tokenizer(self.example_text, return_tensors="pt")["input_ids"]
+        out = model.generate(text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello,I answerthe question.\n\nA"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_tokenization_xnli(self):
         import tqdm
         from datasets import load_dataset
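
For reference, loading one of the converted Mamba GGUF checkpoints follows the same pattern as the new tests; the sketch below uses the repository and file names from this PR, and the generated text will vary with the quantization level:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "jpodivin/mamba-2.8b-hf-GGUF"
gguf_file = "ggml-model-Q6_K.gguf"

# The GGUF tensors are dequantized and remapped onto the regular Mamba layout.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file, torch_dtype=torch.float16)

input_ids = tokenizer("Hello", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```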
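
Note on the `ssm_a` handling in `modeling_gguf_pytorch_utils.py`: the llama.cpp conversion script stores the Mamba `A` tensor as `-exp(A_log)`, so the loader recovers `A_log` with `np.log(-weights)`. A minimal NumPy sketch of that round trip (the tensor values here are made up for illustration):

```python
import numpy as np

# Hypothetical A_log tensor as it would appear in the original HF Mamba checkpoint.
a_log = np.log(np.arange(1.0, 17.0, dtype=np.float32)).reshape(4, 4)

# llama.cpp's convert_hf_to_gguf.py exports A = -exp(A_log) into the GGUF file.
gguf_a = -np.exp(a_log)

# The GGUF loader reverses the transform when mapping ssm_a back to mixer.A_log.
recovered_a_log = np.log(-gguf_a)

# In float16 or quantized GGUF files the round trip is only approximate,
# which is why the new test compares mixer.A_log with atol=1e-3 / rtol=1e-3.
assert np.allclose(a_log, recovered_a_log, atol=1e-6)
```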