From 99443044eb7fc95db009b92da44191d6e962b023 Mon Sep 17 00:00:00 2001 From: akoumpa <153118171+akoumpa@users.noreply.github.com> Date: Thu, 25 Jan 2024 10:45:56 -0800 Subject: [PATCH] Mistral 7b conversion script (#8052) * Import script for mistral-7b. From mistral checkpoint not hf. Pending: support for block-diagonal attention mask. Signed-off-by: Alexandros Koumparoulis * add window_size to nemo_config. Signed-off-by: Alexandros Koumparoulis * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Alexandros Koumparoulis * Switch from Mistral checkpoint to HF-Mistral. Signed-off-by: Alexandros Koumparoulis * Force lowercase when checking for normalization type. Signed-off-by: Alexandros Koumparoulis * NeMo-Mistral-7B to HF-Mistral-7B. Signed-off-by: Alexandros Koumparoulis --------- Signed-off-by: Alexandros Koumparoulis Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper --- .../language_modeling/megatron_gpt_model.py | 2 +- .../convert_hf_mistral_7b_to_nemo.py | 341 ++++++++++++++++++ .../convert_nemo_mistral_7b_to_hf.py | 225 ++++++++++++ 3 files changed, 567 insertions(+), 1 deletion(-) create mode 100644 scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py create mode 100644 scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 85cee72a9a18..7c46b48e12a9 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1596,7 +1596,7 @@ def build_transformer_config(self) -> TransformerConfig: For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. """ - normalization = self.cfg.get('normalization', 'layernorm') + normalization = self.cfg.get('normalization', 'layernorm').lower() layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p' if normalization == 'layernorm': normalization = 'LayerNorm' diff --git a/scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py b/scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py new file mode 100644 index 000000000000..89bf6cc27088 --- /dev/null +++ b/scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py @@ -0,0 +1,341 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Conversion script to convert HuggingFace Mistral-7B checkpoints into nemo checkpoint. 
+ Example to run this conversion script: + python convert_hf_mistral_7b_to_nemo.py \ + --in-file \ + --out-file \ + [--fast-swiglu\ +""" + + +import json +import os +from argparse import ArgumentParser +from collections import OrderedDict + +import torch +import torch.nn +from omegaconf import OmegaConf +from pytorch_lightning.core.saving import _load_state as ptl_load_state +from pytorch_lightning.trainer.trainer import Trainer +from sentencepiece import SentencePieceProcessor + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import ( + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPStrategy, + NLPSaveRestoreConnector, + PipelineMixedPrecisionPlugin, +) +from nemo.utils import logging + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--in-file", type=str, default=None, required=True, help="Path to Huggingface Mistral-7b checkpoints", + ) + parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.") + parser.add_argument("--precision", type=str, default="32", help="Model precision") + args = parser.parse_args() + return args + + +def load_model(cls, checkpoint, strict, **kwargs): + try: + if 'cfg' in kwargs: + model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs) + else: + model = cls(cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs) + for name, module in model.named_parameters(): + if name in checkpoint['state_dict']: + module.data = checkpoint['state_dict'][name] + checkpoint['state_dict'].pop(name) + else: + print(f"Unexpected key: {name} not in checkpoint but in model.") + + for name, buffer in model.named_buffers(): + if name in checkpoint['state_dict']: + buffer.data = checkpoint['state_dict'][name] + checkpoint['state_dict'].pop(name) + + if len(checkpoint['state_dict'].keys()) != 0: + raise RuntimeError( + f"Additional keys: {checkpoint['state_dict'].keys()} in checkpoint but not in model." + ) + + # register the artifacts + cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] + if cfg.tokenizer.model is not None: + model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model) + if cfg.tokenizer.vocab_file is not None: + model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file) + if cfg.tokenizer.merge_file is not None: + model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file) + finally: + cls._set_model_restore_state(is_being_restored=False) + return model + + +def load_config(mistral_config, tokenizer_path): + nemo_config = OmegaConf.load( + os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml') + ).model + # akoumparouli: verify this. + nemo_config.encoder_seq_length = mistral_config['sliding_window'] + nemo_config.num_layers = int(mistral_config['num_hidden_layers']) + nemo_config.hidden_size = mistral_config['hidden_size'] + nemo_config.ffn_hidden_size = mistral_config['intermediate_size'] + nemo_config.num_attention_heads = mistral_config['num_attention_heads'] + nemo_config.max_position_embeddings = mistral_config['max_position_embeddings'] + nemo_config.window_size = [mistral_config['sliding_window'], 0] + nemo_config.init_method_std = mistral_config['initializer_range'] + # RMSNorm's epsilon. 
+ nemo_config.layernorm_epsilon = mistral_config['rms_norm_eps'] + nemo_config.normalization = 'rmsnorm' + + if 'num_key_value_heads' in mistral_config: + nemo_config.num_query_groups = mistral_config['num_key_value_heads'] + nemo_config.use_cpu_initialization = True + # Mistral uses SiLU, but it is the same as swish with beta = 1. + nemo_config.activation = 'fast-swiglu' + + nemo_config.tokenizer.model = tokenizer_path + # TODO(@akoumparouli): rope_scaling. + nemo_config['rotary_base'] = mistral_config['rope_theta'] + + base = 128 + while mistral_config['vocab_size'] % base != 0: + base //= 2 + nemo_config.make_vocab_size_divisible_by = base + + return nemo_config + + +def load_mistral_ckpt(dir): + params_file = os.path.join(dir, 'config.json') + assert os.path.exists(params_file) + with open(params_file, 'r') as fp: + model_args = json.load(fp) + + ckpt = OrderedDict() + ckpt['state_dict'] = OrderedDict() + for i in range(2): + ckpt_file = f'pytorch_model-0000{i+1}-of-00002.bin' + ckpt_path = os.path.join(dir, ckpt_file) + assert os.path.exists(ckpt_path) + ckpt.update(torch.load(ckpt_path)) + tokenizer_file = os.path.join(dir, 'tokenizer.model') + assert os.path.exists(tokenizer_file) + tokenizer = SentencePieceProcessor(model_file=tokenizer_file) + assert tokenizer.get_piece_size() == model_args['vocab_size'] + return model_args, ckpt, tokenizer + + +def convert(args): + logging.info(f"loading checkpoint {args.in_file}") + + model_args, ckpt, tokenizer = load_mistral_ckpt(args.in_file) + nemo_config = load_config(model_args, os.path.join(args.in_file, 'tokenizer.model')) + logging.info(f"loaded checkpoint {args.in_file}") + + if args.precision in ["32", "16"]: + precision = int(float(args.precision)) + elif args.precision in ["bf16", "bf16-mixed"]: + if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + precision = args.precision + else: + logging.warning("BF16 is not supported on this device. 
Using FP16 instead.") + precision = args.precision[2:] # prune bf in string + else: + precision = args.precision + + plugins = [] + if precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']: + scaler = None + if precision in [16, '16', '16-mixed']: + scaler = GradScaler( + init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + growth_interval=nemo_config.get('native_amp_growth_interval', 1000), + hysteresis=nemo_config.get('hysteresis', 2), + ) + # MixedPrecisionPlugin in PTL >= 2.0 requires precision to be 16-mixed or bf16-mixed + plugin_precision = '16-mixed' + else: + plugin_precision = 'bf16-mixed' + + if nemo_config.get('megatron_amp_O2', False): + plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) + + if precision == 32: + dtype = torch.float32 + elif precision in [16, "16", "16-mixed"]: + dtype = torch.float16 + elif precision in ["bf16", "bf16-mixed"]: + dtype = torch.bfloat16 + else: + dtype = torch.float32 # fallback + + nemo_config.precision = precision + logging.info(f"nemo_config: {nemo_config}") + + trainer = Trainer(plugins=plugins, accelerator='cpu', precision=precision, strategy=NLPDDPStrategy()) + + hidden_size = nemo_config.hidden_size + head_num = nemo_config.num_attention_heads + head_size = hidden_size // head_num + num_layers = nemo_config.num_layers + + mcore_gpt = nemo_config.mcore_gpt + + assert mcore_gpt == nemo_config.get( + 'transformer_engine', False + ), "mcore_gpt transformer_engine must be enabled (or disabled) together." + + param_to_weights = lambda param: param.float() + + checkpoint = OrderedDict() + checkpoint['state_dict'] = OrderedDict() + + embed_weight = ckpt[f'model.embed_tokens.weight'] + if mcore_gpt: + embed_weights_base_name = f'model.embedding.word_embeddings.weight' + else: + embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight' + checkpoint['state_dict'][embed_weights_base_name] = param_to_weights(embed_weight) + + if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num: + num_query_groups = head_num + else: + num_query_groups = nemo_config.num_query_groups + assert head_num % num_query_groups == 0, 'head_num must be divisible by num_query_groups' + if mcore_gpt: + assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.' + + for l in range(int(num_layers)): + print(f"converting layer {l}") + old_tensor_shape = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].view(*new_q_tensor_shape) + k = ckpt[f'model.layers.{l}.self_attn.k_proj.weight'].view(*new_kv_tensor_shape) + v = ckpt[f'model.layers.{l}.self_attn.v_proj.weight'].view(*new_kv_tensor_shape) + + # Note: we assume wq & wk have been appropriately transposed to work with + # NeMo/Megatron's rotary embedding. The reference checkpoint/implementation + # will not work OotB without transposing wq/wk matrices. 
+ heads_per_group = head_num // num_query_groups + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + if mcore_gpt: + qkv_weights_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.weight' + else: + qkv_weights_base_name = f'model.language_model.encoder.layers.{l}.self_attention.query_key_value.weight' + checkpoint['state_dict'][qkv_weights_base_name] = param_to_weights(qkv_weights) + + # attention dense + o_weight = ckpt[f'model.layers.{l}.self_attn.o_proj.weight'] + if mcore_gpt: + o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight' + else: + o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight' + checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight) + + # MLP + mlp_down_weight = ckpt[f'model.layers.{l}.mlp.gate_proj.weight'] + mlp_gate_weight = ckpt[f'model.layers.{l}.mlp.up_proj.weight'] + if mcore_gpt: + mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight' + else: + mlp_down_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight' + mlp_down_weight = torch.cat((mlp_down_weight, mlp_gate_weight), axis=0) + checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight) + + mlp_up_weight = ckpt[f'model.layers.{l}.mlp.down_proj.weight'] + if mcore_gpt: + mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight' + else: + mlp_up_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_4h_to_h.weight' + checkpoint['state_dict'][mlp_up_base_name] = param_to_weights(mlp_up_weight) + + # LayerNorm + input_ln_weight = ckpt[f'model.layers.{l}.input_layernorm.weight'] + if mcore_gpt: + input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' + else: + input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight' + checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight) + + post_attn_ln_weight = ckpt[f'model.layers.{l}.post_attention_layernorm.weight'] + if mcore_gpt: + post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight' + else: + post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight' + checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight) + + print(f"done layer {l}") + + final_ln_weight = ckpt[f'model.norm.weight'] + if mcore_gpt: + final_ln_base_name = f'model.decoder.final_layernorm.weight' + else: + final_ln_base_name = f'model.language_model.encoder.final_layernorm.weight' + checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight) + + output_layer_weight = ckpt[f'lm_head.weight'] + if mcore_gpt: + output_layer_base_name = f'model.output_layer.weight' + else: + output_layer_base_name = f'model.language_model.output_layer.weight' + checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight) + 
+ checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config + del ckpt + + if nemo_config.get('megatron_amp_O2', False): + keys = list(checkpoint['state_dict'].keys()) + for key in keys: + checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key) + + model = load_model(MegatronGPTModel, checkpoint, strict=False, trainer=trainer) + + model._save_restore_connector = NLPSaveRestoreConnector() + + # cast to target precision and disable cpu init + model = model.to(dtype=dtype) + model.cfg.use_cpu_initialization = False + + model.save_to(args.out_file) + logging.info(f'NeMo model saved to: {args.out_file}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py b/scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py new file mode 100644 index 000000000000..9e6403acd6c5 --- /dev/null +++ b/scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py @@ -0,0 +1,225 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Conversion script to convert NeMo Mistral-7B checkpoints into HuggingFace checkpoint. + Example to run this conversion script: + python3 convert_nemo_mistral_7b_to_hf.py \ + --in-file \ + --out-file +""" + +from argparse import ArgumentParser +from collections import OrderedDict + +import torch +import torch.nn +from pytorch_lightning.trainer.trainer import Trainer +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.utils import logging + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--in-file", type=str, default=None, required=True, help="Path to NeMo Mistral-7B checkpoint") + parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output HF checkpoint.") + parser.add_argument('--hf-model-name', type=str, default="mistralai/Mistral-7B-v0.1", help="Name of HF checkpoint") + parser.add_argument("--precision", type=str, default="32", help="Model precision") + args = parser.parse_args() + return args + + +def load_config(hf_model_name, nemo_config): + hf_config = AutoConfig.from_pretrained(hf_model_name) + # SWA; nemo_config.window_size is list [left-bound, right-bound] + hf_config.sliding_window = nemo_config.window_size[0] + hf_config.max_position_embeddings = nemo_config.encoder_seq_length + hf_config.num_hidden_layers = nemo_config.num_layers + hf_config.hidden_size = nemo_config.hidden_size + hf_config.intermediate_size = nemo_config.ffn_hidden_size + hf_config.num_attention_heads = nemo_config.num_attention_heads + hf_config.max_position_embeddings = nemo_config.max_position_embeddings + hf_config.initializer_range = nemo_config.init_method_std + hf_config.rms_norm_eps = nemo_config.layernorm_epsilon + 
hf_config.num_key_value_heads = nemo_config.num_query_groups + if nemo_config.activation == 'fast-swiglu': + hf_config.activation = 'silu' + else: + logging.warning(f"Got unknown activation function {nemo_config.activation}") + + hf_config.rope_theta = nemo_config['rotary_base'] + return hf_config + + +def convert(in_file, precision=None, cpu_only=True) -> None: + """ + Convert NeMo checkpoint to HF checkpoint + """ + + logging.info(f'Loading NeMo checkpoint from: {in_file}') + + dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy()) + model_config = MegatronGPTModel.restore_from(in_file, trainer=dummy_trainer, return_config=True) + model_config.tensor_model_parallel_size = 1 + model_config.pipeline_model_parallel_size = 1 + if cpu_only: + map_location = torch.device('cpu') + model_config.use_cpu_initialization = True + else: + map_location = None + + if cpu_only: + logging.info("******** Loading model on CPU. This will take a significant amount of time.") + model = MegatronGPTModel.restore_from( + in_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location + ) + ckpt = model.state_dict() + nemo_config = model.cfg + + mcore_gpt = nemo_config.mcore_gpt + + if precision is None: + precision = model.cfg.precision + if precision in [32, "32"]: + dtype = torch.float32 + elif precision in [16, "16", "16-mixed"]: + dtype = torch.float16 + elif precision in ["bf16", "bf16-mixed"]: + dtype = torch.bfloat16 + else: + logging.warning(f"Precision string {precision} is not recognized, falling back to fp32") + dtype = torch.float32 # fallback + param_to_weights = lambda param: param.to(dtype) + + state_dict = OrderedDict() + + hf_embed_weight_name = f'model.embed_tokens.weight' + if mcore_gpt: + embed_weights_base_name = f'model.embedding.word_embeddings.weight' + else: + embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight' + state_dict[hf_embed_weight_name] = param_to_weights(ckpt[embed_weights_base_name]) + + if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num: + num_query_groups = head_num + else: + num_query_groups = nemo_config.num_query_groups + assert head_num % num_query_groups == 0, 'head_num must be divisible by num_query_groups' + if mcore_gpt: + assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.' 
+ + hidden_size = model.cfg.hidden_size + head_num = model.cfg.num_attention_heads + num_layers = model.cfg.num_layers + num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B + + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + # Embedding + embed_weight = model.state_dict()[f'model.embedding.word_embeddings.weight'] + embed_weights_base_name = f'model.embed_tokens.weight' + state_dict[embed_weights_base_name] = param_to_weights(embed_weight) + + for l in range(int(num_layers)): + print(f"converting layer {l}") + + qkv_weights = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + weight_name = f'model.layers.{l}.self_attn.{name}.weight' + state_dict[weight_name] = param_to_weights(qkv_weights[slice].reshape(-1, hidden_size)) + + # attention dense + hf_o_weight_name = f'model.layers.{l}.self_attn.o_proj.weight' + if mcore_gpt: + o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight' + else: + o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight' + state_dict[hf_o_weight_name] = param_to_weights(ckpt[o_weight_base_name]) + + # # MLP + if mcore_gpt: + mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight' + else: + raise Exception("not implemented") + gate_proj_weight, up_proj_weight = torch.chunk(ckpt[mlp_down_base_name], 2, dim=0) + hf_gate_proj_name = f'model.layers.{l}.mlp.gate_proj.weight' + hf_up_proj_name = f'model.layers.{l}.mlp.up_proj.weight' + state_dict[hf_gate_proj_name] = param_to_weights(gate_proj_weight) + state_dict[hf_up_proj_name] = param_to_weights(up_proj_weight) + + hf_mlp_up_weight_name = f'model.layers.{l}.mlp.down_proj.weight' + if mcore_gpt: + mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight' + else: + raise Exception("not implemented") + state_dict[hf_mlp_up_weight_name] = param_to_weights(ckpt[mlp_up_base_name]) + + # LayerNorm + hf_input_ln_weight_name = f'model.layers.{l}.input_layernorm.weight' + if mcore_gpt: + input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' + else: + input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight' + state_dict[hf_input_ln_weight_name] = param_to_weights(ckpt[input_ln_base_name]) + + hf_post_attn_ln_weight_name = f'model.layers.{l}.post_attention_layernorm.weight' + if mcore_gpt: + post_attn_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight' + else: + post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight' + state_dict[hf_post_attn_ln_weight_name] = param_to_weights(ckpt[post_attn_ln_base_name]) + + hf_final_ln_weight_name = 'model.norm.weight' + if mcore_gpt: + final_ln_base_name = 'model.decoder.final_layernorm.weight' + else: + final_ln_base_name = 'model.language_model.encoder.final_layernorm.weight' + state_dict[hf_final_ln_weight_name] 
= param_to_weights(ckpt[final_ln_base_name]) + + hf_output_layer_weight_name = 'lm_head.weight' + if mcore_gpt: + output_layer_base_name = 'model.output_layer.weight' + else: + output_layer_base_name = 'model.language_model.output_layer.weight' + state_dict[hf_output_layer_weight_name] = param_to_weights(ckpt[output_layer_base_name]) + return state_dict, nemo_config + + +if __name__ == '__main__': + args = get_args() + hf_state_dict, nemo_config = convert(args.in_file, args.precision) + + config = load_config(args.hf_model_name, nemo_config) + model = AutoModelForCausalLM.from_config(config) + model.load_state_dict(hf_state_dict) + model.save_pretrained(args.out_file) + hf_tokenizer = AutoTokenizer.from_pretrained(args.hf_model_name) + hf_tokenizer.save_pretrained(args.out_file) + logging.info(f'HF checkpoint saved to: {args.out_file}')
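
Note on the QKV layout the two scripts share: the HF-to-NeMo script packs the separate q_proj/k_proj/v_proj matrices into a single interleaved linear_qkv weight (per key/value group: its query heads, then its one key head, then its one value head), and the NeMo-to-HF script recovers the three matrices with the q_slice/k_slice/v_slice index arithmetic. The sketch below is a standalone illustration of that round trip with toy dimensions (hidden_size=16, 8 query heads, 2 KV groups are made-up stand-ins, not Mistral-7B's real sizes) and is not part of the patch.

import torch

# Toy grouped-query-attention shapes (illustrative only).
hidden_size = 16
head_num = 8              # query heads
num_query_groups = 2      # key/value heads (GQA)
head_size = hidden_size // head_num
heads_per_group = head_num // num_query_groups

# HF-style projection weights: q is [head_num * head_size, hidden],
# k/v are [num_query_groups * head_size, hidden].
q_proj = torch.randn(head_num * head_size, hidden_size)
k_proj = torch.randn(num_query_groups * head_size, hidden_size)
v_proj = torch.randn(num_query_groups * head_size, hidden_size)

# --- HF -> NeMo direction (mirrors the packing loop in convert_hf_mistral_7b_to_nemo.py) ---
q = q_proj.view(head_num, head_size, hidden_size)
k = k_proj.view(num_query_groups, head_size, hidden_size)
v = v_proj.view(num_query_groups, head_size, hidden_size)

qkv_chunks = []
for i in range(num_query_groups):
    # Each group contributes its query heads, then its single k head, then its single v head.
    qkv_chunks.append(q[i * heads_per_group : (i + 1) * heads_per_group])
    qkv_chunks.append(k[i : i + 1])
    qkv_chunks.append(v[i : i + 1])
qkv_weights = torch.cat(qkv_chunks).reshape(
    head_size * (head_num + 2 * num_query_groups), hidden_size
)

# --- NeMo -> HF direction (mirrors the slicing in convert_nemo_mistral_7b_to_hf.py) ---
qkv_total_dim = head_num + 2 * num_query_groups
packed = qkv_weights.reshape(qkv_total_dim, head_size, hidden_size)

q_slice = torch.cat(
    [
        torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
        for i in range(num_query_groups)
    ]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2)
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2)

q_rec = packed[q_slice].reshape(-1, hidden_size)
k_rec = packed[k_slice].reshape(-1, hidden_size)
v_rec = packed[v_slice].reshape(-1, hidden_size)

# The round trip is exact: only views, concatenation and indexing are involved.
assert torch.equal(q_rec, q_proj)
assert torch.equal(k_rec, k_proj)
assert torch.equal(v_rec, v_proj)
print("QKV interleave/de-interleave round trip OK")

The slice indices in the export script must mirror the packing loop in the import script exactly (query heads at offsets 0..heads_per_group-1 of each group, k at heads_per_group, v at heads_per_group+1); the asserts above check that correspondence on the toy shapes.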