Converter script fixes for mixtral/mistral #8272

Merged
@@ -1697,7 +1697,7 @@ def build_transformer_config(self) -> TransformerConfig:
             'fp8': fp8,
             'tp_comm_overlap': ub_tp_comm_overlap,
             # MoE related
-            'num_experts': self.cfg.get('num_experts', None),
+            'num_moe_experts': self.cfg.get('num_moe_experts', None),
             'moe_router_load_balancing_type': self.cfg.get('moe_router_load_balancing_type', 'aux_loss'),
             'moe_router_topk': self.cfg.get('moe_router_topk', 2),
             'moe_grouped_gemm': self.cfg.get('moe_grouped_gemm', False),
@@ -1708,11 +1708,11 @@ def build_transformer_config(self) -> TransformerConfig:
             'moe_input_jitter_eps': self.cfg.get('moe_input_jitter_eps', None),
             'moe_token_dropping': self.cfg.get('moe_token_dropping', False),  # TODO: Support token dropping.
         }
-        if model_specific_configs['num_experts'] is not None:
+        if model_specific_configs['num_moe_experts'] is not None:
             assert mcore_supports_moe(), 'Megatron-core >= v0.5.0 is required for MoE'
         elif not mcore_supports_moe():
-            if 'num_experts' in model_specific_configs:
-                del model_specific_configs['num_experts']
+            if 'num_moe_experts' in model_specific_configs:
+                del model_specific_configs['num_moe_experts']
             moe_keys = list(filter(lambda x: x.startswith('moe_'), model_specific_configs.keys()))
             for k in moe_keys:
                 del model_specific_configs[k]
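
For readers skimming the hunk above: the PR only renames the config key from 'num_experts' to 'num_moe_experts'; the guard that strips MoE settings on older megatron-core keeps its shape. A minimal, self-contained sketch of that guard (strip_moe_keys and the boolean mcore_supports_moe argument are illustrative stand-ins, not part of the PR):

# Illustrative standalone version of the key-stripping logic in the diff above.
def strip_moe_keys(model_specific_configs: dict, mcore_supports_moe: bool) -> dict:
    if model_specific_configs.get('num_moe_experts') is not None:
        assert mcore_supports_moe, 'Megatron-core >= v0.5.0 is required for MoE'
    elif not mcore_supports_moe:
        # Older megatron-core: drop the expert count and every 'moe_'-prefixed key.
        model_specific_configs.pop('num_moe_experts', None)
        for k in [k for k in model_specific_configs if k.startswith('moe_')]:
            del model_specific_configs[k]
    return model_specific_configs
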
scripts/nlp_language_modeling/convert_hf_mistral_7b_to_nemo.py (22 changes: 8 additions & 14 deletions)
@@ -32,7 +32,7 @@
 from omegaconf import OmegaConf
 from pytorch_lightning.core.saving import _load_state as ptl_load_state
 from pytorch_lightning.trainer.trainer import Trainer
-from sentencepiece import SentencePieceProcessor
+from transformers import AutoModelForCausalLM, AutoTokenizer

 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.parts.nlp_overrides import (
@@ -127,23 +127,17 @@ def load_config(mistral_config, tokenizer_path):
     return nemo_config


-def load_mistral_ckpt(dir):
-    params_file = os.path.join(dir, 'config.json')
+def load_mistral_ckpt(in_dir):
+    params_file = os.path.join(in_dir, 'config.json')
     assert os.path.exists(params_file)
     with open(params_file, 'r') as fp:
         model_args = json.load(fp)

-    ckpt = OrderedDict()
-    ckpt['state_dict'] = OrderedDict()
-    for i in range(2):
-        ckpt_file = f'pytorch_model-0000{i+1}-of-00002.bin'
-        ckpt_path = os.path.join(dir, ckpt_file)
-        assert os.path.exists(ckpt_path)
-        ckpt.update(torch.load(ckpt_path))
-    tokenizer_file = os.path.join(dir, 'tokenizer.model')
-    assert os.path.exists(tokenizer_file)
-    tokenizer = SentencePieceProcessor(model_file=tokenizer_file)
-    assert tokenizer.get_piece_size() == model_args['vocab_size']
+    model = AutoModelForCausalLM.from_pretrained(in_dir)
+    ckpt = model.state_dict()
+
+    tokenizer = AutoTokenizer.from_pretrained(in_dir)
+    assert tokenizer.vocab_size == model_args['vocab_size']
     return model_args, ckpt, tokenizer
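
The rewritten loader leans on the transformers Auto classes instead of hard-coded shard filenames and a SentencePiece tokenizer, so any sharding layout the Hub produces is handled for free. A hedged sketch of that path ("/path/to/mistral-7b-hf" is a placeholder checkpoint directory, not from the PR):

from transformers import AutoModelForCausalLM, AutoTokenizer

in_dir = "/path/to/mistral-7b-hf"                     # placeholder HF checkpoint dir
model = AutoModelForCausalLM.from_pretrained(in_dir)  # resolves all weight shards itself
ckpt = model.state_dict()                             # flat HF state dict used for conversion
tokenizer = AutoTokenizer.from_pretrained(in_dir)
assert tokenizer.vocab_size == model.config.vocab_size  # same vocab check as the script
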
scripts/nlp_language_modeling/convert_hf_mixtral_to_nemo.py (9 changes: 6 additions & 3 deletions)
@@ -81,6 +81,9 @@ def load_model(cls, checkpoint, strict, **kwargs):

         # register the artifacts
         cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
+        assert os.path.exists(
+            cfg.tokenizer.model
+        ), f"Expected cfg.tokenizer.model {cfg.tokenizer.model} to be present"
         if cfg.tokenizer.model is not None:
             model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model)
         if cfg.tokenizer.vocab_file is not None:
@@ -110,8 +113,8 @@ def load_config(mixtral_config, tokenizer_path):
     if 'num_key_value_heads' in mixtral_config:
         nemo_config.num_query_groups = mixtral_config['num_key_value_heads']

-    nemo_config.num_experts = int(mixtral_config['num_local_experts'])
-    assert nemo_config.num_experts > 0, "num_experts must be greater than zero."
+    nemo_config.num_moe_experts = int(mixtral_config['num_local_experts'])
+    assert nemo_config.num_moe_experts > 0, "num_experts must be greater than zero."
     nemo_config.moe_router_topk = int(mixtral_config['num_experts_per_tok'])
     assert nemo_config.moe_router_topk > 0, "moe_router_topk must be greater than zero."
     nemo_config.use_cpu_initialization = True
@@ -266,7 +269,7 @@ def convert(args):
                 raise Exception("not implemented")
             checkpoint['state_dict'][moe_gate_name] = param_to_weights(moe_gate)
             # Handle experts
-            for i in range(nemo_config.num_experts):
+            for i in range(nemo_config.num_moe_experts):
                 gate_proj = ckpt[f'model.layers.{l}.block_sparse_moe.experts.{i}.w1.weight']
                 up_proj = ckpt[f'model.layers.{l}.block_sparse_moe.experts.{i}.w3.weight']
                 if mcore_gpt:
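
For context on the renamed fields: the Mixtral MoE hyperparameters are read straight from the HF config.json. A small sketch of the mapping (the literal values 8 and 2 are Mixtral-8x7B's published settings, used here only as an example):

mixtral_config = {'num_local_experts': 8, 'num_experts_per_tok': 2}  # parsed from config.json
num_moe_experts = int(mixtral_config['num_local_experts'])           # 8 experts per MoE layer
moe_router_topk = int(mixtral_config['num_experts_per_tok'])         # router keeps the top-2 experts
assert num_moe_experts > 0 and moe_router_topk > 0
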
@@ -194,6 +194,7 @@ def get_new_key(old_key):
     convert_dict = convert_state_dict(state_dict_hf, amp=omega_cfg.megatron_amp_O2)

     logging.info("Creating Megatron model...")
+    omega_cfg.cpu_offloading_num_layers = 0
     model = load_state_dict_helper(MegatronGPTModel, omega_cfg, trainer, convert_dict)
     logging.info(f"Created model:\n{model}")