[NLP] Improve and unify loading state_dict for community models (#7977)
* Improve and unify loading state_dict for community models

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
janekl and pre-commit-ci[bot] authored Jan 19, 2024
1 parent dab6a04 commit e329575
Showing 3 changed files with 28 additions and 59 deletions.
22 changes: 22 additions & 0 deletions nemo/collections/nlp/parts/utils_funcs.py
@@ -18,6 +18,7 @@
    'tensor2list',
    'plot_confusion_matrix',
    'get_classification_report',
    'load_state_dict_helper',
]

import os
@@ -28,6 +29,8 @@
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer
from sklearn.metrics import classification_report, confusion_matrix
from torch import Tensor

@@ -207,3 +210,22 @@ def activation_to_func(activation: str, openai_gelu: bool = False, onnx_safe: bo
        activation_func = squared_relu

    return activation_func


def load_state_dict_helper(cls, cfg: DictConfig, trainer: Trainer, state_dict: Dict[str, torch.Tensor]):
    """Load state_dict for converted community models, for example, HuggingFace models."""
    model = cls(cfg, trainer)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        # Keys ending with '_extra_state' are related to Transformer Engine internals
        missing_keys_non_extra = [key for key in missing_keys if not key.endswith("_extra_state")]
        if missing_keys_non_extra:
            logging.critical("Missing keys were detected during the load; something has gone wrong. Aborting.")
            raise RuntimeError(f"Missing keys: \n{missing_keys_non_extra}")

    if unexpected_keys:
        logging.critical("Unexpected keys were detected which should not happen. Aborting.")
        raise RuntimeError(f"Unexpected keys: \n{unexpected_keys}")

    return model
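For reference, both converter scripts below now route through this helper. A minimal usage sketch, assuming the calling script has already prepared a model config, a Lightning Trainer, and a converted state_dict (the names nemo_config, trainer, and converted_state_dict are illustrative placeholders, not part of this diff):

# Minimal sketch: load converted HuggingFace weights into a NeMo model.
# Assumes nemo_config (DictConfig), trainer (Trainer), and
# converted_state_dict (Dict[str, torch.Tensor]) were built by the caller.
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper

model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, converted_state_dict)
model.save_to("model.nemo")  # persist the loaded model, as the converters below do

Because the helper raises RuntimeError on any unexpected key, and on any missing key other than Transformer Engine '_extra_state' entries, a silent partial load cannot slip through.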
44 changes: 2 additions & 42 deletions scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
@@ -26,21 +26,18 @@

import torch
from omegaconf import OmegaConf
from pytorch_lightning.core.saving import _load_state as ptl_load_state
from pytorch_lightning.trainer.trainer import Trainer
from transformers import LlamaForCausalLM, LlamaTokenizer


from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
    PipelineMixedPrecisionPlugin,
)
from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper, torch_dtype_from_precision
from nemo.utils import logging


@@ -55,47 +52,10 @@ def get_args():
    return args


def load_model(cls, checkpoint, strict, **kwargs):
    try:
        if 'cfg' in kwargs:
            model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
        else:
            model = cls(cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs)
            for name, module in model.named_parameters():
                if name in checkpoint['state_dict']:
                    module.data = checkpoint['state_dict'][name]
                    checkpoint['state_dict'].pop(name)
                else:
                    print(f"Unexpected key: {name} not in checkpoint but in model.")

            for name, buffer in model.named_buffers():
                if name in checkpoint['state_dict']:
                    buffer.data = checkpoint['state_dict'][name]
                    checkpoint['state_dict'].pop(name)

            if len(checkpoint['state_dict'].keys()) != 0:
                raise RuntimeError(
                    f"Additional keys: {checkpoint['state_dict'].keys()} in checkpoint but not in model."
                )

            # register the artifacts
            cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
            if cfg.tokenizer.model is not None:
                model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model)
            if cfg.tokenizer.vocab_file is not None:
                model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
            if cfg.tokenizer.merge_file is not None:
                model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file)
    finally:
        cls._set_model_restore_state(is_being_restored=False)
    return model


def load_config(llama_config):
    nemo_config = OmegaConf.load(
        os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml')
    ).model

    if llama_config.get('rope_theta', None):
        nemo_config['rotary_base'] = llama_config['rope_theta']
    nemo_config.encoder_seq_length = llama_config['max_position_embeddings']
@@ -301,7 +261,7 @@ def convert(args):
        for key in keys:
            checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key)

    model = load_model(MegatronGPTModel, checkpoint, strict=False, trainer=trainer)
    model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, checkpoint['state_dict'])

    model._save_restore_connector = NLPSaveRestoreConnector()

21 changes: 4 additions & 17 deletions scripts/nlp_language_modeling/convert_starcoder_hf_to_nemo.py
@@ -61,6 +61,7 @@

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper
from nemo.utils import logging


@@ -185,30 +186,16 @@ def get_new_key(old_key):

trainer = pl.Trainer(**trainer_dict)

logging.info("Creating Megatron model...")
model = MegatronGPTModel(omega_cfg, trainer)
logging.info(f"Created model:\n{model}")

logging.info("Loading HuggingFace model...")
model_hf = AutoModelForCausalLM.from_pretrained(args.input)
logging.info(f"Loaded model:\n{model_hf}")

state_dict_hf = model_hf.state_dict()
convert_dict = convert_state_dict(state_dict_hf, amp=omega_cfg.megatron_amp_O2)

logging.info("Loading state dict...")
missing_keys, unexpected_keys = model.load_state_dict(convert_dict, strict=False)

if missing_keys:
    # Keys ending with '_extra_state' are related to Transformer Engine internals
    missing_keys_non_extra = [key for key in missing_keys if not key.endswith("_extra_state")]
    if missing_keys_non_extra:
        logging.critical("Missing keys were detected during the load, something has gone wrong. Aborting.")
        raise RuntimeError(f"Missing keys: \n{missing_keys_non_extra}")

if unexpected_keys:
    logging.critical("Unexpected keys were detected which should not happen. Aborting.")
    raise RuntimeError(f"Unexpected keys: \n{unexpected_keys}")
logging.info("Creating Megatron model...")
model = load_state_dict_helper(MegatronGPTModel, omega_cfg, trainer, convert_dict)
logging.info(f"Created model:\n{model}")

logging.info("Saving model...")
# We make sure that the tokenizer can be instantiated later regardless of args.input
