Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mistral 7b conversion script #8052

Merged
merged 8 commits into from
Jan 25, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -1584,7 +1584,7 @@ def build_transformer_config(self) -> TransformerConfig:
For attributes in TransformerConfig that are not in the nemo model config, we add custom logic.
"""

normalization = self.cfg.get('normalization', 'layernorm')
normalization = self.cfg.get('normalization', 'layernorm').lower()
layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p'
if normalization == 'layernorm':
normalization = 'LayerNorm'
Expand Down
341 changes: 341 additions & 0 deletions scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,341 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Conversion script to convert HuggingFace Mistral-7B checkpoints into nemo checkpoint.
Example to run this conversion script:
python convert_hf_mistral_7b_to_nemo.py \
--in-file <path_to_mistral_checkpoints_folder> \
--out-file <path_to_output_nemo_file> \
[--fast-swiglu\
"""


import json
import os
from argparse import ArgumentParser
from collections import OrderedDict

import torch
import torch.nn
from omegaconf import OmegaConf
from pytorch_lightning.core.saving import _load_state as ptl_load_state
from pytorch_lightning.trainer.trainer import Trainer
from sentencepiece import SentencePieceProcessor

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.utils import logging


def get_args():
parser = ArgumentParser()
parser.add_argument(
"--in-file", type=str, default=None, required=True, help="Path to Huggingface Mistral-7b checkpoints",
)
parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.")
parser.add_argument("--precision", type=str, default="32", help="Model precision")
args = parser.parse_args()
return args


def load_model(cls, checkpoint, strict, **kwargs):
try:
if 'cfg' in kwargs:
model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
else:
model = cls(cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs)
for name, module in model.named_parameters():
if name in checkpoint['state_dict']:
module.data = checkpoint['state_dict'][name]
checkpoint['state_dict'].pop(name)
else:
print(f"Unexpected key: {name} not in checkpoint but in model.")

for name, buffer in model.named_buffers():
if name in checkpoint['state_dict']:
buffer.data = checkpoint['state_dict'][name]
checkpoint['state_dict'].pop(name)

if len(checkpoint['state_dict'].keys()) != 0:
raise RuntimeError(
f"Additional keys: {checkpoint['state_dict'].keys()} in checkpoint but not in model."
)

# register the artifacts
cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
if cfg.tokenizer.model is not None:
model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model)
if cfg.tokenizer.vocab_file is not None:
model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
if cfg.tokenizer.merge_file is not None:
model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file)
finally:
cls._set_model_restore_state(is_being_restored=False)
return model


def load_config(mistral_config, tokenizer_path):
nemo_config = OmegaConf.load(
os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml')
).model
# akoumparouli: verify this.
nemo_config.encoder_seq_length = mistral_config['sliding_window']
nemo_config.num_layers = int(mistral_config['num_hidden_layers'])
nemo_config.hidden_size = mistral_config['hidden_size']
nemo_config.ffn_hidden_size = mistral_config['intermediate_size']
nemo_config.num_attention_heads = mistral_config['num_attention_heads']
nemo_config.max_position_embeddings = mistral_config['max_position_embeddings']
nemo_config.window_size = [mistral_config['sliding_window'], 0]
nemo_config.init_method_std = mistral_config['initializer_range']
# RMSNorm's epsilon.
nemo_config.layernorm_epsilon = mistral_config['rms_norm_eps']
nemo_config.normalization = 'rmsnorm'

if 'num_key_value_heads' in mistral_config:
nemo_config.num_query_groups = mistral_config['num_key_value_heads']
nemo_config.use_cpu_initialization = True
# Mistral uses SiLU, but it is the same as swish with beta = 1.
nemo_config.activation = 'fast-swiglu'

nemo_config.tokenizer.model = tokenizer_path
# TODO(@akoumparouli): rope_scaling.
nemo_config['rotary_base'] = mistral_config['rope_theta']

base = 128
while mistral_config['vocab_size'] % base != 0:
base //= 2
nemo_config.make_vocab_size_divisible_by = base

return nemo_config


def load_mistral_ckpt(dir):
params_file = os.path.join(dir, 'config.json')
assert os.path.exists(params_file)
with open(params_file, 'r') as fp:
model_args = json.load(fp)

ckpt = OrderedDict()
ckpt['state_dict'] = OrderedDict()
for i in range(2):
ckpt_file = f'pytorch_model-0000{i+1}-of-00002.bin'
ckpt_path = os.path.join(dir, ckpt_file)
assert os.path.exists(ckpt_path)
ckpt.update(torch.load(ckpt_path))
tokenizer_file = os.path.join(dir, 'tokenizer.model')
assert os.path.exists(tokenizer_file)
tokenizer = SentencePieceProcessor(model_file=tokenizer_file)
assert tokenizer.get_piece_size() == model_args['vocab_size']
return model_args, ckpt, tokenizer


def convert(args):
logging.info(f"loading checkpoint {args.in_file}")

model_args, ckpt, tokenizer = load_mistral_ckpt(args.in_file)
nemo_config = load_config(model_args, os.path.join(args.in_file, 'tokenizer.model'))
logging.info(f"loaded checkpoint {args.in_file}")

if args.precision in ["32", "16"]:
precision = int(float(args.precision))
elif args.precision in ["bf16", "bf16-mixed"]:
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
precision = args.precision
else:
logging.warning("BF16 is not supported on this device. Using FP16 instead.")
precision = args.precision[2:] # prune bf in string
else:
precision = args.precision

plugins = []
if precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
scaler = None
if precision in [16, '16', '16-mixed']:
scaler = GradScaler(
init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32),
growth_interval=nemo_config.get('native_amp_growth_interval', 1000),
hysteresis=nemo_config.get('hysteresis', 2),
)
# MixedPrecisionPlugin in PTL >= 2.0 requires precision to be 16-mixed or bf16-mixed
plugin_precision = '16-mixed'
else:
plugin_precision = 'bf16-mixed'

if nemo_config.get('megatron_amp_O2', False):
plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))
else:
plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler))

if precision == 32:
dtype = torch.float32
elif precision in [16, "16", "16-mixed"]:
dtype = torch.float16
elif precision in ["bf16", "bf16-mixed"]:
dtype = torch.bfloat16
else:
dtype = torch.float32 # fallback

nemo_config.precision = precision
logging.info(f"nemo_config: {nemo_config}")

trainer = Trainer(plugins=plugins, accelerator='cpu', precision=precision, strategy=NLPDDPStrategy())

hidden_size = nemo_config.hidden_size
head_num = nemo_config.num_attention_heads
head_size = hidden_size // head_num
num_layers = nemo_config.num_layers

mcore_gpt = nemo_config.mcore_gpt

assert mcore_gpt == nemo_config.get(
'transformer_engine', False
), "mcore_gpt transformer_engine must be enabled (or disabled) together."

param_to_weights = lambda param: param.float()

checkpoint = OrderedDict()
checkpoint['state_dict'] = OrderedDict()

embed_weight = ckpt[f'model.embed_tokens.weight']
if mcore_gpt:
embed_weights_base_name = f'model.embedding.word_embeddings.weight'
else:
embed_weights_base_name = f'model.language_model.embedding.word_embeddings.weight'
checkpoint['state_dict'][embed_weights_base_name] = param_to_weights(embed_weight)

if nemo_config.num_query_groups is None or nemo_config.num_query_groups == head_num:
num_query_groups = head_num
else:
num_query_groups = nemo_config.num_query_groups
assert head_num % num_query_groups == 0, 'head_num must be divisible by num_query_groups'
if mcore_gpt:
assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.'

for l in range(int(num_layers)):
print(f"converting layer {l}")
old_tensor_shape = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].size()
new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:]
new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:]

q = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].view(*new_q_tensor_shape)
k = ckpt[f'model.layers.{l}.self_attn.k_proj.weight'].view(*new_kv_tensor_shape)
v = ckpt[f'model.layers.{l}.self_attn.v_proj.weight'].view(*new_kv_tensor_shape)

# Note: we assume wq & wk have been appropriately transposed to work with
# NeMo/Megatron's rotary embedding. The reference checkpoint/implementation
# will not work OotB without transposing wq/wk matrices.
heads_per_group = head_num // num_query_groups
qkv_weights_l = []
for i in range(num_query_groups):
qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
qkv_weights_l.append(k[i : i + 1, :, :])
qkv_weights_l.append(v[i : i + 1, :, :])
qkv_weights = torch.cat(qkv_weights_l)
assert qkv_weights.ndim == 3, qkv_weights.shape
assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
assert qkv_weights.shape[1] == head_size, qkv_weights.shape
assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape
qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])
if mcore_gpt:
qkv_weights_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.weight'
else:
qkv_weights_base_name = f'model.language_model.encoder.layers.{l}.self_attention.query_key_value.weight'
checkpoint['state_dict'][qkv_weights_base_name] = param_to_weights(qkv_weights)

# attention dense
o_weight = ckpt[f'model.layers.{l}.self_attn.o_proj.weight']
if mcore_gpt:
o_weight_base_name = f'model.decoder.layers.{l}.self_attention.linear_proj.weight'
else:
o_weight_base_name = f'model.language_model.encoder.layers.{l}.self_attention.dense.weight'
checkpoint['state_dict'][o_weight_base_name] = param_to_weights(o_weight)

# MLP
mlp_down_weight = ckpt[f'model.layers.{l}.mlp.gate_proj.weight']
mlp_gate_weight = ckpt[f'model.layers.{l}.mlp.up_proj.weight']
if mcore_gpt:
mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight'
else:
mlp_down_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight'
mlp_down_weight = torch.cat((mlp_down_weight, mlp_gate_weight), axis=0)
checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight)

mlp_up_weight = ckpt[f'model.layers.{l}.mlp.down_proj.weight']
if mcore_gpt:
mlp_up_base_name = f'model.decoder.layers.{l}.mlp.linear_fc2.weight'
else:
mlp_up_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_4h_to_h.weight'
checkpoint['state_dict'][mlp_up_base_name] = param_to_weights(mlp_up_weight)

# LayerNorm
input_ln_weight = ckpt[f'model.layers.{l}.input_layernorm.weight']
if mcore_gpt:
input_ln_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight'
else:
input_ln_base_name = f'model.language_model.encoder.layers.{l}.input_layernorm.weight'
checkpoint['state_dict'][input_ln_base_name] = param_to_weights(input_ln_weight)

post_attn_ln_weight = ckpt[f'model.layers.{l}.post_attention_layernorm.weight']
if mcore_gpt:
post_attn_ln_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight'
else:
post_attn_ln_base_name = f'model.language_model.encoder.layers.{l}.post_attention_layernorm.weight'
checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)

print(f"done layer {l}")

final_ln_weight = ckpt[f'model.norm.weight']
if mcore_gpt:
final_ln_base_name = f'model.decoder.final_layernorm.weight'
else:
final_ln_base_name = f'model.language_model.encoder.final_layernorm.weight'
checkpoint['state_dict'][final_ln_base_name] = param_to_weights(final_ln_weight)

output_layer_weight = ckpt[f'lm_head.weight']
if mcore_gpt:
output_layer_base_name = f'model.output_layer.weight'
else:
output_layer_base_name = f'model.language_model.output_layer.weight'
checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)

checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config
del ckpt

if nemo_config.get('megatron_amp_O2', False):
keys = list(checkpoint['state_dict'].keys())
for key in keys:
checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key)

model = load_model(MegatronGPTModel, checkpoint, strict=False, trainer=trainer)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function load_model is a bit convolved, can you avoid using it? I recommend instantiating model with sth like

model = MegatronGPTModel(cfg, trainer)
missing_keys, unexpected_keys = model.load_state_dict(hf_state_dict, strict=False)

see also #7977.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if the model.load_state_dict with strict=False verifies if any weight were loaded at all? I'm not implying that load_model performs this check, rather want to understand if this is something we care doing at this stage of the script.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That PR #7977 has been merged. Could you please replace load_model with load_state_dict_helper? The former is really cluttered and unnecessarily saves two tokenizers -- the one stored in tokenizer.tokenizer_model is not needed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And as for strict=False there are checks for missing_keys and unexpected_keys lists in load_state_dict_helper that will complain if any expected weights are not loaded, or are superfluous.


model._save_restore_connector = NLPSaveRestoreConnector()

# cast to target precision and disable cpu init
model = model.to(dtype=dtype)
model.cfg.use_cpu_initialization = False

model.save_to(args.out_file)
logging.info(f'NeMo model saved to: {args.out_file}')


if __name__ == '__main__':
args = get_args()
convert(args)
Loading
Loading