diff --git a/examples/nlp/language_modeling/megatron_export.py b/examples/nlp/language_modeling/megatron_export.py
index bf572c68309a..45dd05444bc5 100644
--- a/examples/nlp/language_modeling/megatron_export.py
+++ b/examples/nlp/language_modeling/megatron_export.py
@@ -30,7 +30,6 @@
 
 from omegaconf import OmegaConf, open_dict
 from pytorch_lightning import Trainer
-from torch.export import Dim
 
 from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel
 from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
@@ -75,7 +74,6 @@ def nemo_export(cfg):
     assert nemo_in is not None, "NeMo model not provided. Please provide the path to the .nemo or .ckpt file"
 
     onnx_out = cfg.onnx_model_file
-    print(f"onnx_out: {onnx_out}")
     trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
     assert (
         cfg.trainer.devices * cfg.trainer.num_nodes
@@ -154,11 +152,6 @@ def nemo_export(cfg):
 
         sequence = "sequence"
         batch = "batch"
-        use_dynamo = False
-        if use_dynamo:
-            sequence = Dim("sequence")
-            batch = Dim("batch")
-
         model.export(
             onnx_out,
             onnx_opset_version=cfg.export_options.onnx_opset,
@@ -171,7 +164,6 @@ def nemo_export(cfg):
                 'position_ids': {0: sequence, 1: batch},
                 'logits': {0: sequence, 1: batch},
             },
-            use_dynamo=use_dynamo,
         )
     except Exception as e:
         logging.error(
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index dc4ea7353f39..4f9722d900f6 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -196,7 +196,7 @@ def __init__(self, model):
 
         self.dtype = utils_funcs.torch_dtype_from_precision(model.cfg.precision)
 
-    def forward(self, input_ids, position_ids, attention_mask):
+    def forward(self, tokens, position_ids, attention_mask):
         if self.fp8_enabled and HAVE_TE:
             with (
                 transformer_engine.pytorch.onnx_export(self.fp8_enabled),
@@ -207,12 +207,10 @@ def forward(self, input_ids, position_ids, attention_mask):
                 warnings.catch_warnings(),
             ):
                 warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
-                assert input_ids.shape == position_ids.shape
-                assert (
-                    attention_mask.shape[2] == attention_mask.shape[3] == input_ids.shape[1] == position_ids.shape[1]
-                )
+                assert tokens.shape == position_ids.shape
+                assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1]
                 output_tensor = self.model.forward(
-                    tokens=input_ids.cuda(),
+                    tokens=tokens.cuda(),
                     text_position_ids=position_ids.cuda(),
                     attention_mask=attention_mask.cuda(),
                     labels=None,
@@ -225,12 +223,10 @@ def forward(self, input_ids, position_ids, attention_mask):
                 warnings.catch_warnings(),
             ):
                 warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
-                assert input_ids.shape == position_ids.shape
-                assert (
-                    attention_mask.shape[2] == attention_mask.shape[3] == input_ids.shape[1] == position_ids.shape[1]
-                )
+                assert tokens.shape == position_ids.shape
+                assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1]
                 output_tensor = self.model.forward(
-                    tokens=input_ids.cuda(),
+                    tokens=tokens.cuda(),
                     text_position_ids=position_ids.cuda(),
                     attention_mask=attention_mask.cuda(),
                     labels=None,