NeMo-Mistral-7B to HF-Mistral-7B.
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
akoumpa committed Jan 17, 2024
1 parent e37328e commit fab45fc
Showing 1 changed file with 230 additions and 0 deletions.
scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py (+230, -0)
@@ -0,0 +1,230 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Script to convert a NeMo Mistral-7B checkpoint into a HuggingFace checkpoint.
Example to run this conversion script:
python3 convert_nemo_mistral_7b_to_hf.py \
--in-file <path_to_nemo_checkpoints_folder> \
--out-file <path_to_output_hf_file>
"""

from argparse import ArgumentParser
from collections import OrderedDict

import torch
import torch.nn
from pytorch_lightning.trainer.trainer import Trainer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.utils import logging


def get_args():
parser = ArgumentParser()
parser.add_argument("--in-file", type=str, default=None, required=True, help="Path to NeMo Mistral-7B checkpoint")
parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output HF checkpoint.")
parser.add_argument('--hf-model-name', type=str, default="mistralai/Mistral-7B-v0.1", help="Name of HF checkpoint")
parser.add_argument("--precision", type=str, default="32", help="Model precision")
args = parser.parse_args()
return args


def load_config(hf_model_name, nemo_config):
hf_config = AutoConfig.from_pretrained(hf_model_name)
    # Sliding-window attention (SWA): nemo_config.window_size is a list [left-bound, right-bound].
hf_config.sliding_window = nemo_config.window_size[0]
hf_config.max_position_embeddings = nemo_config.encoder_seq_length
hf_config.num_hidden_layers = nemo_config.num_layers
hf_config.hidden_size = nemo_config.hidden_size
hf_config.intermediate_size = nemo_config.ffn_hidden_size
hf_config.num_attention_heads = nemo_config.num_attention_heads
hf_config.max_position_embeddings = nemo_config.max_position_embeddings
hf_config.initializer_range = nemo_config.init_method_std
hf_config.rms_norm_eps = nemo_config.layernorm_epsilon
hf_config.num_key_value_heads = nemo_config.num_query_groups
if nemo_config.activation == 'fast-swiglu':
        hf_config.hidden_act = 'silu'
else:
logging.warning(f"Got unknown activation function {nemo_config.activation}")

hf_config.rope_theta = nemo_config['rotary_base']
return hf_config


def convert(in_file, precision=None):
"""
Convert NeMo checkpoint to HF checkpoint
"""

logging.info(f'Loading NeMo checkpoint from: {in_file}')

dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
model_config = MegatronGPTModel.restore_from(in_file, trainer=dummy_trainer, return_config=True)
model_config.tensor_model_parallel_size = 1
model_config.pipeline_model_parallel_size = 1
cpu_only = True
if cpu_only:
map_location = torch.device('cpu')
model_config.use_cpu_initialization = True
else:
map_location = None

if cpu_only:
logging.info("******** Loading model on CPU. This will take a significant amount of time.")
model = MegatronGPTModel.restore_from(
in_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
)
ckpt = model.state_dict()
nemo_config = model.cfg

mcore_gpt = nemo_config.mcore_gpt
hidden_size = nemo_config.hidden_size
head_num = nemo_config.num_attention_heads
    head_size = hidden_size // head_num
    num_layers = nemo_config.num_layers

if precision is None:
precision = model.cfg.precision
if precision in [32, "32"]:
dtype = torch.float32
elif precision in [16, "16", "16-mixed"]:
dtype = torch.float16
elif precision in ["bf16", "bf16-mixed"]:
dtype = torch.bfloat16
else:
logging.warning(f"Precision string {precision} is not recognized, falling back to fp32")
dtype = torch.float32 # fallback
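    # Cast each tensor to the chosen export dtype as it is copied into the HF state dict.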
param_to_weights = lambda param: param.to(dtype)

state_dict = OrderedDict()

hf_embed_weight_name = f'model.embed_tokens.weight'
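    # Token embeddings: NeMo's word_embeddings tensor maps one-to-one onto HF's model.embed_tokens.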
if mcore_gpt:
embed_weights_base_name = f'model.embedding.word_embeddings.weight'
else:
embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight'
state_dict[hf_embed_weight_name] = param_to_weights(ckpt[embed_weights_base_name])

if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num:
num_query_groups = head_num
else:
num_query_groups = nemo_config.num_query_groups
assert head_num % num_query_groups == 0, 'head_num must be divisible by num_query_groups'
if mcore_gpt:
assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.'

hidden_size = model.cfg.hidden_size
head_num = model.cfg.num_attention_heads
num_layers = model.cfg.num_layers
    num_query_groups = model.cfg.get("num_query_groups", head_num)  # Mistral-7B uses grouped-query attention

head_size = hidden_size // head_num
heads_per_group = head_num // num_query_groups
qkv_total_dim = head_num + 2 * num_query_groups
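    # The fused linear_qkv weight stores, for each of the num_query_groups groups,
    # heads_per_group query heads followed by one key head and one value head, i.e.
    # qkv_total_dim rows of head_size each. For Mistral-7B (hidden_size=4096,
    # 32 attention heads, 8 query groups) that is head_size=128, heads_per_group=4
    # and qkv_total_dim = 32 + 2 * 8 = 48.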

# Embedding
embed_weight = model.state_dict()[f'model.embedding.word_embeddings.weight']
embed_weights_base_name = f'model.embed_tokens.weight'
state_dict[embed_weights_base_name] = param_to_weights(embed_weight)

for l in range(int(num_layers)):
print(f"converting layer {l}")

qkv_weights = model.state_dict()[f'model.decoder.layers.{l}.self_attention.linear_qkv.weight']
qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])

q_slice = torch.cat(
[
torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
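        # Each group occupies heads_per_group + 2 consecutive rows laid out as
        # [q, ..., q, k, v]; with heads_per_group=4, q_slice picks rows {0..3, 6..9, ...},
        # k_slice rows {4, 10, 16, ...} and v_slice rows {5, 11, 17, ...}.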

for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]:
weight_name = f'model.layers.{l}.self_attn.{name}.weight'
state_dict[weight_name] = param_to_weights(qkv_weights[slice].reshape(-1, hidden_size))

# attention dense
hf_o_weight_name = f'model.layers.{l}.self_attn.o_proj.weight'
if mcore_gpt:
o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight'
else:
o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight'
state_dict[hf_o_weight_name] = param_to_weights(ckpt[o_weight_base_name])

        # MLP
if mcore_gpt:
mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight'
else:
raise Exception("not implemented")
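        # linear_fc1 packs the SwiGLU gate and up projections concatenated along dim 0
        # (each intermediate_size x hidden_size, i.e. 14336 x 4096 for Mistral-7B),
        # so chunking in two along dim 0 recovers HF's gate_proj and up_proj.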
gate_proj_weight, up_proj_weight = torch.chunk(ckpt[mlp_down_base_name], 2, dim=0)
hf_gate_proj_name = f'model.layers.{l}.mlp.gate_proj.weight'
hf_up_proj_name = f'model.layers.{l}.mlp.up_proj.weight'
state_dict[hf_gate_proj_name] = param_to_weights(gate_proj_weight)
state_dict[hf_up_proj_name] = param_to_weights(up_proj_weight)

hf_mlp_up_weight_name = f'model.layers.{l}.mlp.down_proj.weight'
if mcore_gpt:
mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight'
else:
raise Exception("not implemented")
state_dict[hf_mlp_up_weight_name] = param_to_weights(ckpt[mlp_up_base_name])

# LayerNorm
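        # With mcore + TransformerEngine the RMSNorm weights are fused into the following
        # linear layer: the pre-attention norm lives in linear_qkv.layer_norm_weight and the
        # pre-MLP norm (HF's post_attention_layernorm) in mlp.linear_fc1.layer_norm_weight.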
hf_input_ln_weight_name = f'model.layers.{l}.input_layernorm.weight'
if mcore_gpt:
input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'
else:
            input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight'
state_dict[hf_input_ln_weight_name] = param_to_weights(ckpt[input_ln_base_name])

hf_post_attn_ln_weight_name = f'model.layers.{l}.post_attention_layernorm.weight'
if mcore_gpt:
            post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'
else:
            post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight'
state_dict[hf_post_attn_ln_weight_name] = param_to_weights(ckpt[post_attn_ln_base_name])

hf_final_ln_weight_name = 'model.norm.weight'
if mcore_gpt:
final_ln_base_name = 'model.decoder.final_layernorm.weight'
else:
final_ln_base_name = 'model.language_model.encoder.final_layernorm.weight'
state_dict[hf_final_ln_weight_name] = param_to_weights(ckpt[final_ln_base_name])

hf_output_layer_weight_name = 'lm_head.weight'
if mcore_gpt:
output_layer_base_name = 'model.output_layer.weight'
else:
output_layer_base_name = 'model.language_model.output_layer.weight'
state_dict[hf_output_layer_weight_name] = param_to_weights(ckpt[output_layer_base_name])
return state_dict, nemo_config


if __name__ == '__main__':
args = get_args()
hf_state_dict, nemo_config = convert(args.in_file, args.precision)

config = load_config(args.hf_model_name, nemo_config)
model = AutoModelForCausalLM.from_config(config)
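    # load_state_dict defaults to strict=True, so any missing or unexpected key in the
    # converted state dict raises an error here instead of silently producing a broken model.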
model.load_state_dict(hf_state_dict)
model.save_pretrained(args.out_file)
hf_tokenizer = AutoTokenizer.from_pretrained(args.hf_model_name)
hf_tokenizer.save_pretrained(args.out_file)
logging.info(f'HF checkpoint saved to: {args.out_file}')
