diff --git a/keras_hub/src/utils/transformers/export_gemma_to_safetensor.py b/keras_hub/src/utils/transformers/export_gemma_to_safetensor.py
new file mode 100644
index 0000000000..0c621aaed3
--- /dev/null
+++ b/keras_hub/src/utils/transformers/export_gemma_to_safetensor.py
@@ -0,0 +1,140 @@
+import json
+import os
+import shutil
+import warnings
+
+import torch
+from safetensors.torch import save_file
+
+
+def convert_to_hf_config(keras_config):
+    """Convert a Keras Gemma backbone config to a Hugging Face config dict."""
+    hf_config = {
+        "vocab_size": keras_config.vocabulary_size,
+        "num_hidden_layers": keras_config.num_layers,
+        "num_attention_heads": keras_config.num_query_heads,
+        "num_key_value_heads": keras_config.num_key_value_heads,
+        "hidden_size": keras_config.hidden_dim,
+        # KerasHub's `intermediate_dim` spans both gating projections, so it
+        # is twice the Hugging Face `intermediate_size`.
+        "intermediate_size": keras_config.intermediate_dim // 2,
+        "head_dim": keras_config.head_dim,
+        "max_position_embeddings": 8192,
+    }
+    return hf_config
+
+
+def export_to_hf(keras_model, path):
+    """Convert a Keras Gemma model to Hugging Face safetensors format.
+
+    This function:
+    - Extracts and maps weights from the Keras backbone to safetensors.
+    - Saves the configuration as `config.json`.
+    - Saves the weights in `model.safetensors`.
+    - Saves the tokenizer assets.
+
+    Args:
+        keras_model: The Keras Gemma model (e.g., `GemmaCausalLM`) to
+            convert.
+        path: str. Path of the directory to which the safetensors file,
+            config and tokenizer will be saved.
+    """
+    backbone = keras_model.backbone
+    hf_config = convert_to_hf_config(backbone)
+
+    weights_dict = {}
+
+    # Map token embedding
+    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
+    weights_dict["model.embed_tokens.weight"] = torch.from_numpy(
+        token_embedding
+    )
+
+    for i in range(backbone.num_layers):
+        decoder_layer = backbone.get_layer(f"decoder_block_{i}")
+
+        # Pre-attention normalization
+        pre_attn_norm = decoder_layer.pre_attention_norm.get_weights()[0]
+        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
+            torch.from_numpy(pre_attn_norm)
+        )
+
+        # Attention query projection. The Keras kernel has shape
+        # (num_heads, hidden_dim, head_dim); Hugging Face expects
+        # (num_heads * head_dim, hidden_dim), so flatten the head axes last.
+        query_kernel = decoder_layer.attention.query_dense.get_weights()[0]
+        query_kernel = (
+            torch.from_numpy(query_kernel)
+            .permute(1, 0, 2)
+            .reshape(backbone.hidden_dim, -1)
+            .T
+        )
+        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel
+
+        # Attention key projection. The leading `[0]` selects the single
+        # key/value head, so this assumes multi-query attention (as in
+        # Gemma 2B).
+        key_kernel = decoder_layer.attention.key_dense.get_weights()[0][0]
+        key_kernel = torch.from_numpy(key_kernel).T
+        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = key_kernel
+
+        # Attention value projection
+        value_kernel = decoder_layer.attention.value_dense.get_weights()[0][0]
+        value_kernel = torch.from_numpy(value_kernel).T
+        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = value_kernel
+
+        # Attention output projection, reshaped from
+        # (num_heads, head_dim, hidden_dim) to
+        # (hidden_dim, num_heads * head_dim).
+        out_kernel = decoder_layer.attention.output_dense.get_weights()[0]
+        out_kernel = (
+            torch.from_numpy(out_kernel)
+            .permute(2, 0, 1)
+            .reshape(backbone.hidden_dim, -1)
+        )
+        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel
+
+        # Post-attention normalization
+        post_attn_norm = decoder_layer.pre_ffw_norm.get_weights()[0]
+        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
+            torch.from_numpy(post_attn_norm)
+        )
+
+        # MLP gate projection
+        gate_kernel = decoder_layer.gating_ffw.get_weights()[0]
+        gate_kernel = torch.from_numpy(gate_kernel).T
+        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel
+
+        # MLP up projection
+        up_kernel = decoder_layer.gating_ffw_2.get_weights()[0]
+        up_kernel = torch.from_numpy(up_kernel).T
weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel + + # MLP down projection + down_kernel = decoder_layer.ffw_linear.get_weights()[0] + down_kernel = torch.from_numpy(down_kernel).T + weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel + + # Map final normalization + final_norm = backbone.get_layer("final_normalization").get_weights()[0] + weights_dict["model.norm.weight"] = torch.from_numpy(final_norm) + + # Tie lm_head.weight to embedding weights + weights_dict["lm_head.weight"] = weights_dict[ + "model.embed_tokens.weight" + ].clone() + + # Save config + os.makedirs(path, exist_ok=True) + config_path = os.path.join(path, "config.json") + with open(config_path, "w") as f: + json.dump(hf_config, f) + + # Make tensors contiguous before saving + weights_dict_contiguous = { + k: v.contiguous() for k, v in weights_dict.items() + } + + # Save weights + weights_path = os.path.join(path, "model.safetensors") + save_file(weights_dict_contiguous, weights_path) + + # Save tokenizer assets + keras_model.preprocessor.tokenizer.save_assets(path) + + # Rename vocabulary file + vocab_spm_path = os.path.join(path, "vocabulary.spm") + tokenizer_model_path = os.path.join(path, "tokenizer.model") + if os.path.exists(vocab_spm_path): + shutil.move(vocab_spm_path, tokenizer_model_path) + else: + warnings.warn( + f"{vocab_spm_path} not found. Tokenizer may not load correctly." + ) diff --git a/keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py b/keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py new file mode 100644 index 0000000000..249afa475a --- /dev/null +++ b/keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py @@ -0,0 +1,44 @@ +import os + +import pytest +import torch +from transformers import GemmaForCausalLM +from transformers import GemmaTokenizer + +from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM +from keras_hub.src.tests.test_case import TestCase +from keras_hub.src.utils.transformers.export_gemma_to_safetensor import ( + export_to_hf, +) + + +class TestGemmaExport(TestCase): + @pytest.mark.large + def test_export_to_hf(self): + # Load Keras model + keras_model = GemmaCausalLM.from_preset("gemma_2b_en") + input_text = "All hail RCB" + max_length = 25 + + # Export to Hugging Face format using self.tmp_path + export_path = os.path.join(self.get_temp_dir(), "export_to_hf") + export_to_hf(keras_model, export_path) + + # Load Hugging Face model and tokenizer + hf_model = GemmaForCausalLM.from_pretrained(export_path) + hf_tokenizer = GemmaTokenizer.from_pretrained(export_path) + + # Generate text with Keras model + keras_output = keras_model.generate(input_text, max_length=max_length) + + # Generate text with Hugging Face model + hf_inputs = hf_tokenizer(input_text, return_tensors="pt") + with torch.no_grad(): + hf_outputs = hf_model.generate( + **hf_inputs, max_length=max_length, do_sample=False + ) + hf_output_text = hf_tokenizer.decode( + hf_outputs[0], skip_special_tokens=True + ) + + self.assertEqual(keras_output, hf_output_text) diff --git a/requirements-common.txt b/requirements-common.txt index da331b567a..a98ed71301 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -18,3 +18,4 @@ sentencepiece tensorflow-datasets safetensors pillow +transformers