Commit
Apply isort and black reformatting
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
akoumpa committed Oct 23, 2024
1 parent 9841082 commit 11afd6d
Showing 1 changed file with 6 additions and 9 deletions.
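The changes below are mechanical formatter output: isort orders and deduplicates imports, and black rewraps lines against the project's configured line length. As a minimal sketch of the same transformation through the formatters' Python APIs (the line_length of 119 is an assumption about the repository's black config, not something shown in this commit):

    # Minimal sketch, assuming the isort and black packages are installed.
    # line_length=119 is an assumed project setting, not taken from this commit.
    import black
    import isort

    src = (
        "import b\n"
        "import a\n"
        "import a\n"
        'parser.add_argument("--num-experts", type=int, default=8, required=True, '
        'help="Number of experts to use in upcycled model.")\n'
    )

    deduped = isort.code(src)  # sorts the imports and drops the duplicate "import a"
    formatted = black.format_str(deduped, mode=black.Mode(line_length=119))
    print(formatted)  # the long add_argument call is rewrapped across lines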
15 changes: 6 additions & 9 deletions examples/nlp/language_modeling/upcyle_dense_to_moe.py
@@ -28,10 +28,6 @@
 import torch.nn
 from pytorch_lightning.trainer.trainer import Trainer
 
-from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
-from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
-from nemo.utils import logging
-
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.parts.nlp_overrides import (
     GradScaler,
@@ -44,13 +40,13 @@
 
 def get_args():
     parser = ArgumentParser()
-    parser.add_argument(
-        "--model", type=str, default=None, required=True, help="Path to NeMo checkpoint"
-    )
+    parser.add_argument("--model", type=str, default=None, required=True, help="Path to NeMo checkpoint")
     parser.add_argument(
         "--output-path", type=str, default='', required=False, help="Path to NeMo save upcycled checkpoint"
     )
-    parser.add_argument("--num-experts", type=int, default=8, required=True, help="Number of experts to use in upcycled model.")
+    parser.add_argument(
+        "--num-experts", type=int, default=8, required=True, help="Number of experts to use in upcycled model."
+    )
     args = parser.parse_args()
     assert isinstance(args.num_experts, int)
     assert args.num_experts > 1, "Expected --num-experts to be greater-than 1."
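The rewrap leaves the CLI surface unchanged. For reference, a hypothetical invocation (the .nemo file names are placeholders, not from this commit) would be: python examples/nlp/language_modeling/upcyle_dense_to_moe.py --model dense.nemo --num-experts 8 --output-path upcycled.nemo. The asserts above reject any --num-experts value of 1 or less.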
@@ -61,10 +57,12 @@ def get_args():
 
 def make_moe_config_from_dense(config, num_experts=8):
     from copy import deepcopy
+
     moe_config = deepcopy(config)
     moe_config['num_moe_experts'] = num_experts
     return moe_config
 
+
 def upcycle(in_file, num_experts, cpu_only=True) -> None:
     """
     Upcycle dense checkpoint to MoE.
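make_moe_config_from_dense is a plain copy-and-override helper: it deep-copies the dense config and sets num_moe_experts, leaving the original untouched. A self-contained illustration of that pattern (the config keys and values below are hypothetical stand-ins, not NeMo's actual schema):

    # Minimal sketch of the copy-and-override pattern used above.
    # The config contents are hypothetical, not NeMo's real config.
    from copy import deepcopy

    def make_moe_config_from_dense(config, num_experts=8):
        moe_config = deepcopy(config)  # leave the dense config untouched
        moe_config['num_moe_experts'] = num_experts
        return moe_config

    dense_cfg = {'hidden_size': 4096, 'num_layers': 32}  # hypothetical values
    moe_cfg = make_moe_config_from_dense(dense_cfg, num_experts=8)
    assert 'num_moe_experts' not in dense_cfg  # original is unmodified
    assert moe_cfg['num_moe_experts'] == 8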
@@ -98,15 +96,14 @@ def upcycle(in_file, num_experts, cpu_only=True) -> None:
 
     # convert state dict dense -> MoE
     from megatron.core.transformer.moe.upcycling_utils import upcycle_state_dict
 
     moe_state_dict = upcycle_state_dict([moe_model.model.module], [dense_model.model.module])
     moe_model.model.module.load_state_dict(moe_state_dict['model'])
     moe_model._save_restore_connector = NLPSaveRestoreConnector()
 
     moe_model.save_to(args.output_path)
 
 
-
 if __name__ == '__main__':
     args = get_args()
     upcycle(args.model, args.num_experts)
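upcycle_state_dict itself lives in Megatron-Core and is not shown in this diff; conceptually, sparse upcycling initializes every expert from the dense model's MLP weights, so the MoE starts out as a functional copy of the dense network. A toy sketch of that idea only (the key names and shapes are invented for illustration, not Megatron-Core's actual state-dict layout):

    # Toy sketch of dense -> MoE weight upcycling: replicate the dense MLP
    # weights into each expert. Key names are hypothetical, not Megatron-Core's.
    import torch

    def toy_upcycle(dense_sd: dict, num_experts: int) -> dict:
        moe_sd = {}
        for name, tensor in dense_sd.items():
            if 'mlp' in name:
                # every expert starts as an exact copy of the dense MLP
                for e in range(num_experts):
                    moe_sd[name.replace('mlp', f'experts.{e}')] = tensor.clone()
            else:
                moe_sd[name] = tensor.clone()  # non-MLP weights carried over as-is
        return moe_sd

    dense_sd = {
        'layers.0.mlp.weight': torch.randn(16, 8),
        'layers.0.attn.weight': torch.randn(8, 8),
    }
    moe_sd = toy_upcycle(dense_sd, num_experts=4)
    assert torch.equal(moe_sd['layers.0.experts.0.weight'], moe_sd['layers.0.experts.3.weight'])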
