[NLP] Improve and unify loading state_dict for community models #7977

Merged · 9 commits · Jan 19, 2024
22 changes: 22 additions & 0 deletions nemo/collections/nlp/parts/utils_funcs.py
@@ -18,6 +18,7 @@
    'tensor2list',
    'plot_confusion_matrix',
    'get_classification_report',
    'load_state_dict_helper',
]

import os
@@ -28,6 +29,8 @@
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer
from sklearn.metrics import classification_report, confusion_matrix
from torch import Tensor

@@ -207,3 +210,22 @@ def activation_to_func(activation: str, openai_gelu: bool = False, onnx_safe: bo
        activation_func = squared_relu

    return activation_func


def load_state_dict_helper(cls, cfg: DictConfig, trainer: Trainer, state_dict: Dict[str, torch.Tensor]):
    """Load a state_dict into a new ``cls`` model; intended for converted community (e.g., HuggingFace) models."""
    model = cls(cfg, trainer)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        # Keys ending with '_extra_state' are related to Transformer Engine internals
        missing_keys_non_extra = [key for key in missing_keys if not key.endswith("_extra_state")]
        if missing_keys_non_extra:
            logging.critical("Missing keys were detected during the load, something has gone wrong. Aborting.")
            raise RuntimeError(f"Missing keys: \n{missing_keys_non_extra}")

    if unexpected_keys:
        logging.critical("Unexpected keys were detected which should not happen. Aborting.")
        raise RuntimeError(f"Unexpected keys: \n{unexpected_keys}")

    return model
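
For illustration, a minimal sketch of how a conversion script is expected to call the new helper. The trainer settings, config file name, and the convert_hf_weights step below are assumptions for the example, not part of this change:

from omegaconf import OmegaConf
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper

# Hypothetical converter snippet: build a minimal trainer and a model config first.
trainer = Trainer(devices=1, accelerator="cpu")
nemo_config = OmegaConf.load("megatron_llama_config.yaml").model  # assumed config file

# state_dict holds HuggingFace weights already renamed to the NeMo/Megatron layout.
state_dict = convert_hf_weights(...)  # hypothetical conversion step

# Constructs the model and loads the weights; raises on missing or unexpected keys
# (keys ending in '_extra_state' are tolerated as Transformer Engine internals).
model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, state_dict)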
44 changes: 2 additions & 42 deletions scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
@@ -26,21 +26,18 @@

import torch
from omegaconf import OmegaConf
from pytorch_lightning.core.saving import _load_state as ptl_load_state
from pytorch_lightning.trainer.trainer import Trainer
from transformers import LlamaForCausalLM, LlamaTokenizer


from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
    PipelineMixedPrecisionPlugin,
)
from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper, torch_dtype_from_precision
from nemo.utils import logging


@@ -55,47 +52,10 @@ def get_args():
    return args


def load_model(cls, checkpoint, strict, **kwargs):
    try:
        if 'cfg' in kwargs:
            model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
        else:
            model = cls(cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs)
            for name, module in model.named_parameters():
                if name in checkpoint['state_dict']:
                    module.data = checkpoint['state_dict'][name]
                    checkpoint['state_dict'].pop(name)
                else:
                    print(f"Unexpected key: {name} not in checkpoint but in model.")

            for name, buffer in model.named_buffers():
                if name in checkpoint['state_dict']:
                    buffer.data = checkpoint['state_dict'][name]
                    checkpoint['state_dict'].pop(name)

            if len(checkpoint['state_dict'].keys()) != 0:
                raise RuntimeError(
                    f"Additional keys: {checkpoint['state_dict'].keys()} in checkpoint but not in model."
                )

            # register the artifacts
            cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
            if cfg.tokenizer.model is not None:
                model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model)
            if cfg.tokenizer.vocab_file is not None:
                model.register_artifact("tokenizer.vocab_file", cfg.tokenizer.vocab_file)
            if cfg.tokenizer.merge_file is not None:
                model.register_artifact("tokenizer.merge_file", cfg.tokenizer.merge_file)
    finally:
        cls._set_model_restore_state(is_being_restored=False)
    return model


def load_config(llama_config):
    nemo_config = OmegaConf.load(
        os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml')
    ).model

    if llama_config.get('rope_theta', None):
        nemo_config['rotary_base'] = llama_config['rope_theta']
    nemo_config.encoder_seq_length = llama_config['max_position_embeddings']
@@ -301,7 +261,7 @@ def convert(args):
        for key in keys:
            checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key)

    model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, checkpoint['state_dict'])

    model._save_restore_connector = NLPSaveRestoreConnector()
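
With the old load_model helper gone, the converter's tail reduces to roughly the following sketch; the dtype cast and the output-path argument name are assumptions based on the imports above, not shown in this hunk:

# Sketch of the converter's final steps after this change (not the full script).
model = load_state_dict_helper(MegatronGPTModel, nemo_config, trainer, checkpoint['state_dict'])
model._save_restore_connector = NLPSaveRestoreConnector()
model = model.to(dtype=torch_dtype_from_precision(precision))  # assumed half/bf16 cast
model.save_to(args.output_path)  # assumed argument name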

21 changes: 4 additions & 17 deletions scripts/nlp_language_modeling/convert_starcoder_hf_to_nemo.py
@@ -61,6 +61,7 @@

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.collections.nlp.parts.utils_funcs import load_state_dict_helper
from nemo.utils import logging


@@ -185,30 +186,16 @@ def get_new_key(old_key):

trainer = pl.Trainer(**trainer_dict)

logging.info("Creating Megatron model...")
model = MegatronGPTModel(omega_cfg, trainer)
logging.info(f"Created model:\n{model}")

logging.info("Loading HuggingFace model...")
model_hf = AutoModelForCausalLM.from_pretrained(args.input)
logging.info(f"Loaded model:\n{model_hf}")

state_dict_hf = model_hf.state_dict()
convert_dict = convert_state_dict(state_dict_hf, amp=omega_cfg.megatron_amp_O2)

logging.info("Loading state dict...")
missing_keys, unexpected_keys = model.load_state_dict(convert_dict, strict=False)

if missing_keys:
    # Keys ending with '_extra_state' are related to Transformer Engine internals
    missing_keys_non_extra = [key for key in missing_keys if not key.endswith("_extra_state")]
    if missing_keys_non_extra:
        logging.critical("Missing keys were detected during the load, something has gone wrong. Aborting.")
        raise RuntimeError(f"Missing keys: \n{missing_keys_non_extra}")

if unexpected_keys:
    logging.critical("Unexpected keys were detected which should not happen. Aborting.")
    raise RuntimeError(f"Unexpected keys: \n{unexpected_keys}")
logging.info("Creating Megatron model...")
model = load_state_dict_helper(MegatronGPTModel, omega_cfg, trainer, convert_dict)
logging.info(f"Created model:\n{model}")

logging.info("Saving model...")
# We make sure that the tokenizer can be instantiated later regardless of args.input
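
For context, convert_state_dict (called above) renames the HuggingFace parameter keys to the Megatron layout before the helper loads them. A rough sketch of that pattern, assuming get_new_key performs the per-tensor rename and that O2 models expect a 'model.module.' prefix; the actual implementation lives elsewhere in this script:

def convert_state_dict(state_dict_hf, amp=False):
    # Sketch only: get_new_key is assumed to translate HF names to NeMo/Megatron names.
    prefix = "model.module." if amp else "model."
    return {prefix + get_new_key(key): tensor for key, tensor in state_dict_hf.items()}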