From cde10a8c33358f24f1670beffbdabb77f93b964f Mon Sep 17 00:00:00 2001 From: akoumpa <153118171+akoumpa@users.noreply.github.com> Date: Fri, 16 Feb 2024 09:11:18 -0800 Subject: [PATCH] NeMo-Mistral to HF converter bugfix. (#8353) Signed-off-by: Alexandros Koumparoulis --- .../nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py b/scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py index 9e6403acd6c59..e78567d605546 100644 --- a/scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py +++ b/scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py @@ -114,6 +114,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight' state_dict[hf_embed_weight_name] = param_to_weights(ckpt[embed_weights_base_name]) + head_num = model.cfg.num_attention_heads if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num: num_query_groups = head_num else: @@ -123,7 +124,6 @@ def convert(in_file, precision=None, cpu_only=True) -> None: assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.' hidden_size = model.cfg.hidden_size - head_num = model.cfg.num_attention_heads num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B @@ -191,7 +191,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: hf_post_attn_ln_weight_name = f'model.layers.{l}.post_attention_layernorm.weight' if mcore_gpt: - post_attn_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' + post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight' else: post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight' state_dict[hf_post_attn_ln_weight_name] = param_to_weights(ckpt[post_attn_ln_base_name])