Add support for Gemma 2 checkpoints
This is sadly a little hacky, as Flax support for Gemma 2 is not yet complete. Output checking will therefore not match up, but we can still convert checkpoints.
mattdangerw committed Jul 19, 2024
1 parent b0c21b3 commit 88f80b0
Showing 1 changed file with 24 additions and 2 deletions.
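The version detection is a one-line check on the checkpoint's parameter tree, presumably because the Flax config alone does not yet expose a Gemma 2 flag. A minimal sketch of the heuristic as a standalone helper (hypothetical name; the real check lives inline in convert_model below):

```python
# Hypothetical helper mirroring the inline check in convert_model below.
# Gemma 2 layers carry extra normalization parameters that Gemma 1 lacks,
# so their presence in layer_0 is used as the version signal.
def looks_like_gemma_2(flax_params):
    return "post_attention_norm" in flax_params["transformer"]["layer_0"]
```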
26 changes: 24 additions & 2 deletions tools/checkpoint_conversion/convert_gemma_checkpoints.py
@@ -86,7 +86,19 @@ def download_flax_model(handle):
     return kagglehub.model_download(handle)
 
 
-def convert_model(flax_config, vocab_size):
+def convert_model(flax_config, flax_params, vocab_size):
+    kwargs = {}
+    # Hacky way to infer Gemma 2 until Flax actually adds support.
+    if "post_attention_norm" in flax_params["transformer"]["layer_0"]:
+        kwargs = {
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "final_logit_soft_cap": 30,
+            "attention_logit_soft_cap": 50,
+            "use_sliding_window_attention": True,
+            "sliding_window_size": 4096,
+        }
     return keras_nlp.models.GemmaBackbone(
         vocabulary_size=vocab_size,
         num_layers=flax_config.num_layers,
@@ -95,6 +107,7 @@ def convert_model(flax_config, vocab_size):
         hidden_dim=flax_config.embed_dim,
         intermediate_dim=flax_config.hidden_dim * 2,
         head_dim=flax_config.head_dim,
+        **kwargs,
     )
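The two soft-cap values above (30 for the final logits, 50 for the attention logits) correspond to Gemma 2's tanh soft-capping, which GemmaBackbone is expected to apply internally once these arguments are set. A small self-contained sketch of what that operation does, using the standard formulation (illustration only, not code from this script):

```python
import numpy as np

def soft_cap(logits, cap):
    # Smoothly squashes values into (-cap, cap) instead of hard clipping.
    return cap * np.tanh(logits / cap)

print(soft_cap(np.array([0.0, 25.0, 500.0]), cap=30.0))  # approx. [0.0, 20.5, 30.0]
```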


@@ -123,6 +136,15 @@ def convert_weights(keras_model, flax_config, flax_params):
             [flax_block["pre_ffw_norm"]["scale"]]
         )
 
+        if "post_attention_norm" in flax_block:
+            keras_block.post_attention_norm.set_weights(
+                [flax_block["post_attention_norm"]["scale"]]
+            )
+        if "post_ffw_norm" in flax_block:
+            keras_block.post_ffw_norm.set_weights(
+                [flax_block["post_ffw_norm"]["scale"]]
+            )
+
         keras_block.gating_ffw.set_weights(
             [flax_block["mlp"]["gating_einsum"][0]]
         )
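Each of the new norms holds a single scale vector, which is why set_weights() receives a one-element list. A self-contained illustration of that Keras API pattern using a generic layer (not the actual Gemma norm class):

```python
import numpy as np
import keras

# set_weights() expects one array per layer variable, in creation order.
layer = keras.layers.Dense(4, use_bias=False)
layer.build((None, 8))                       # creates the single kernel variable
new_kernel = np.ones((8, 4), dtype="float32")
layer.set_weights([new_kernel])              # one-element list, like the norm scales above
assert np.array_equal(layer.get_weights()[0], new_kernel)
```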
@@ -236,7 +258,7 @@ def main(_):
 
     keras_tokenizer = convert_tokenizer(proto_path)
     vocab_size = keras_tokenizer.vocabulary_size()
-    keras_model = convert_model(flax_config, vocab_size)
+    keras_model = convert_model(flax_config, flax_params, vocab_size)
     print("✅ Keras model loaded")
 
     convert_weights(keras_model, flax_config, flax_params)
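With these hunks applied, the only call-site change is that main() now threads flax_params into convert_model. Gemma 1 checkpoints take the same path as before: kwargs stays an empty dict, and expanding it adds nothing to the GemmaBackbone call. A toy sketch of why the empty expansion is a no-op (hypothetical function, not the real constructor):

```python
def build(vocabulary_size, **overrides):
    # Stand-in for the GemmaBackbone(...) call in convert_model.
    return {"vocabulary_size": vocabulary_size, **overrides}

assert build(256_000) == {"vocabulary_size": 256_000}  # Gemma 1 path: kwargs == {}
assert build(256_000, sliding_window_size=4096)["sliding_window_size"] == 4096  # Gemma 2 path
```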
