[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jan 17, 2024
1 parent 780ff60 commit 4dc48ae
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions scripts/nlp_language_modeling/convert_nemo_mistral_7b_to_hf.py
@@ -26,28 +26,19 @@
 import torch
 import torch.nn
 from pytorch_lightning.trainer.trainer import Trainer
-from transformers import AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
 from nemo.utils import logging
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 
 def get_args():
     parser = ArgumentParser()
-    parser.add_argument(
-        "--in-file", type=str, default=None, required=True, help="Path to NeMo Mistral-7B checkpoint"
-    )
-    parser.add_argument(
-        "--out-file", type=str, default=None, required=True, help="Path to output HF checkpoint."
-    )
-    parser.add_argument(
-        '--hf-model-name', type=str, default="mistralai/Mistral-7B-v0.1", help="Name of HF checkpoint"
-    )
-    parser.add_argument(
-        "--precision", type=str, default="32", help="Model precision"
-    )
+    parser.add_argument("--in-file", type=str, default=None, required=True, help="Path to NeMo Mistral-7B checkpoint")
+    parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output HF checkpoint.")
+    parser.add_argument('--hf-model-name', type=str, default="mistralai/Mistral-7B-v0.1", help="Name of HF checkpoint")
+    parser.add_argument("--precision", type=str, default="32", help="Model precision")
     args = parser.parse_args()
     return args

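The hunk above is behavior-preserving: the hooks merge the two transformers imports and collapse the multi-line add_argument calls, but the parser accepts the same flags before and after. As a sanity check, here is a minimal standalone sketch of the same parser with a hypothetical invocation (the file paths are placeholders, not from the commit):

from argparse import ArgumentParser

# Standalone copy of the parser from the hunk above, used to check defaults.
parser = ArgumentParser()
parser.add_argument("--in-file", type=str, default=None, required=True, help="Path to NeMo Mistral-7B checkpoint")
parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output HF checkpoint.")
parser.add_argument('--hf-model-name', type=str, default="mistralai/Mistral-7B-v0.1", help="Name of HF checkpoint")
parser.add_argument("--precision", type=str, default="32", help="Model precision")

# Hypothetical command line: only the two required flags are passed, so the
# HF model name and precision fall back to their defaults.
args = parser.parse_args(["--in-file", "/path/to/mistral-7b.nemo", "--out-file", "/path/to/hf_checkpoint"])
print(args.hf_model_name, args.precision)  # mistralai/Mistral-7B-v0.1 32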
@@ -106,7 +97,6 @@ def convert(in_file, precision=None) -> None:
     head_size = hidden_size // head_num
     num_layers = nemo_config.num_layers
 
-
     if precision is None:
         precision = model.cfg.precision
     if precision in [32, "32"]:
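This hunk only removes a duplicated blank line; the diff view cuts off at the precision check, but the visible lines show a string-or-int precision value being normalized. A hedged sketch of that pattern, assuming the usual Lightning precision spellings (only the 32-bit branch is visible in the diff; the other branches are assumptions):

import torch

def dtype_from_precision(precision):
    # Only the [32, "32"] check appears in the hunk above; the remaining
    # branches are assumed, modeled on common Lightning precision flags.
    if precision in [32, "32"]:
        return torch.float32
    if precision in [16, "16", "16-mixed"]:
        return torch.float16
    if precision in ["bf16", "bf16-mixed"]:
        return torch.bfloat16
    raise ValueError(f"Unsupported precision: {precision}")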
@@ -226,6 +216,7 @@ def convert(in_file, precision=None) -> None:
     state_dict[hf_output_layer_weight_name] = param_to_weights(ckpt[output_layer_base_name])
     return state_dict, nemo_config
 
+
 if __name__ == '__main__':
     args = get_args()
     hf_state_dict, nemo_config = convert(args.in_file, args.precision)
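The final hunk only adds the second blank line PEP 8 expects before a top-level block, and the diff view truncates the rest of __main__. Given the imports added above (AutoConfig, AutoModelForCausalLM, AutoTokenizer) and the --out-file help text, a plausible sketch of what happens after convert() returns is shown below; none of it is confirmed by the visible hunks, and the literal values are placeholders for the parsed arguments:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

hf_model_name = "mistralai/Mistral-7B-v0.1"  # would come from args.hf_model_name
out_file = "/path/to/hf_checkpoint"          # would come from args.out_file

# Build an empty HF model from the reference config, load the converted
# weights returned by convert(), and write model plus tokenizer to disk.
config = AutoConfig.from_pretrained(hf_model_name)
model = AutoModelForCausalLM.from_config(config)
model.load_state_dict(hf_state_dict)  # hf_state_dict from convert() above
model.save_pretrained(out_file)

AutoTokenizer.from_pretrained(hf_model_name).save_pretrained(out_file)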