diff --git a/Dockerfile b/Dockerfile index c919281e1bee..ad200c76bb0b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,8 +42,13 @@ RUN apt-get update && \ libavdevice-dev && \ rm -rf /var/lib/apt/lists/* -WORKDIR /tmp/ +WORKDIR /workspace/ +# Install Megatron-core +RUN git clone https://github.com/aklife97/Megatron-LM.git && \ + cd Megatron-LM && \ + pip install -e . +WORKDIR /tmp/ # TODO: Remove once this Apex commit (2/24/23) is included in PyTorch # container RUN git clone https://github.com/NVIDIA/apex.git && \ diff --git a/Jenkinsfile b/Jenkinsfile index 4d14b18fad62..4055c3a608cb 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -57,6 +57,14 @@ pipeline { } } + // TODO: remove when pip package is available + stage('Megatron Core installation') { + steps { + sh 'git clone https://github.com/aklife97/Megatron-LM.git && \ + cd Megatron-LM && \ + pip install -e .' + } + } stage('PyTorch Lightning version') { steps { diff --git a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py index 89c4a52d1a49..5ec767c34a10 100644 --- a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py +++ b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py @@ -28,7 +28,7 @@ from argparse import ArgumentParser import torch -from apex.transformer import parallel_state +from megatron.core import parallel_state from pytorch_lightning.plugins.environments import TorchElasticEnvironment from pytorch_lightning.trainer.trainer import Trainer @@ -121,9 +121,9 @@ def convert(local_rank, rank, world_size, args): app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=app_state.tensor_model_parallel_size, - pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size, - pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank, + tensor_model_parallel_size=app_state.tensor_model_parallel_size, + pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, ) app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank() diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index a0a8c4db67a9..d797937850e0 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -33,11 +33,13 @@ from nemo.utils.model_utils import inject_model_parallel_rank try: - from apex.transformer import parallel_state + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False """ This is the script to run GPT text generation. diff --git a/examples/nlp/language_modeling/megatron_gpt_prompt_learning_eval.py b/examples/nlp/language_modeling/megatron_gpt_prompt_learning_eval.py index a75c5279c786..d66bac0bfecc 100644 --- a/examples/nlp/language_modeling/megatron_gpt_prompt_learning_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_prompt_learning_eval.py @@ -14,7 +14,7 @@ import torch import torch.multiprocessing as mp -from apex.transformer import parallel_state +from megatron.core import parallel_state from omegaconf import OmegaConf from omegaconf.omegaconf import open_dict from pytorch_lightning.trainer.trainer import Trainer diff --git a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py index da94cef168e7..5404e4b18f8b 100644 --- a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py +++ b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py @@ -42,6 +42,7 @@ from typing import Any, Optional import torch +from megatron.core import parallel_state from pytorch_lightning.core.saving import _load_state as ptl_load_state from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml from pytorch_lightning.trainer.trainer import Trainer diff --git a/examples/nlp/language_modeling/megatron_retro_eval.py b/examples/nlp/language_modeling/megatron_retro_eval.py index e496d5b95f78..79b1e2debdfa 100644 --- a/examples/nlp/language_modeling/megatron_retro_eval.py +++ b/examples/nlp/language_modeling/megatron_retro_eval.py @@ -25,11 +25,13 @@ from nemo.core.config import hydra_runner try: - from apex.transformer import parallel_state + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False """ This is the script to run RETRO Model text generation. diff --git a/examples/nlp/language_modeling/megatron_t5_prompt_learning_eval.py b/examples/nlp/language_modeling/megatron_t5_prompt_learning_eval.py index 6c04d978d2a9..3b932e99ced3 100644 --- a/examples/nlp/language_modeling/megatron_t5_prompt_learning_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_prompt_learning_eval.py @@ -26,11 +26,11 @@ from nemo.utils.app_state import AppState try: - from apex.transformer import parallel_state + from megatron.core import parallel_state - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False if not torch.cuda.is_available(): diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_adapter_eval.py b/examples/nlp/language_modeling/tuning/megatron_gpt_adapter_eval.py index 9d4f27079f98..a4408b7b1c3d 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_adapter_eval.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_adapter_eval.py @@ -15,7 +15,7 @@ import torch import torch.multiprocessing as mp -from apex.transformer import parallel_state +from megatron.core import parallel_state from omegaconf import OmegaConf from omegaconf.omegaconf import open_dict from pytorch_lightning.trainer.trainer import Trainer diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_ia3_eval.py b/examples/nlp/language_modeling/tuning/megatron_gpt_ia3_eval.py index 4034a9810b6d..b780ff821c47 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_ia3_eval.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_ia3_eval.py @@ -15,7 +15,7 @@ import torch import torch.multiprocessing as mp -from apex.transformer import parallel_state +from megatron.core import parallel_state from omegaconf import OmegaConf from omegaconf.omegaconf import open_dict from pytorch_lightning.trainer.trainer import Trainer diff --git a/examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py b/examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py index f987a6072184..bdeec2b5a9c1 100644 --- a/examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py +++ b/examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py @@ -15,7 +15,7 @@ import torch import torch.multiprocessing as mp -from apex.transformer import parallel_state +from megatron.core import parallel_state from omegaconf import OmegaConf from omegaconf.omegaconf import open_dict from pytorch_lightning.trainer.trainer import Trainer diff --git a/examples/nlp/language_modeling/tuning/megatron_t5_ia3_eval.py b/examples/nlp/language_modeling/tuning/megatron_t5_ia3_eval.py index 80ffb0909a7f..8a8ddae166e1 100644 --- a/examples/nlp/language_modeling/tuning/megatron_t5_ia3_eval.py +++ b/examples/nlp/language_modeling/tuning/megatron_t5_ia3_eval.py @@ -15,7 +15,7 @@ import torch import torch.multiprocessing as mp -from apex.transformer import parallel_state +from megatron.core import parallel_state from omegaconf import OmegaConf from omegaconf.omegaconf import open_dict from pytorch_lightning.trainer.trainer import Trainer diff --git a/examples/nlp/machine_translation/megatron_nmt_training.py b/examples/nlp/machine_translation/megatron_nmt_training.py index 0d05d7329f1c..7fd211447196 100644 --- a/examples/nlp/machine_translation/megatron_nmt_training.py +++ b/examples/nlp/machine_translation/megatron_nmt_training.py @@ -13,6 +13,7 @@ # limitations under the License. +import torch.multiprocessing as mp from omegaconf.omegaconf import OmegaConf, open_dict from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelSummary @@ -33,6 +34,8 @@ from nemo.utils import logging from nemo.utils.exp_manager import exp_manager +mp.set_start_method("spawn", force=True) + @hydra_runner(config_path="conf", config_name="aayn_base_megatron") def main(cfg) -> None: diff --git a/nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py b/nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py index 8a1ebed77d3a..2a14aa5afc58 100644 --- a/nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py +++ b/nemo/collections/nlp/data/glue_benchmark/glue_benchmark_dataset.py @@ -382,6 +382,7 @@ def __init__( max_seq_length_decoder: int = 128, use_cache: bool = True, prefix_override: str = None, + pad_to_max_length: bool = True, ): """ Processes GLUE datasets @@ -392,10 +393,12 @@ def __init__( max_seq_length: max sequence length minus 2 for [CLS] and [SEP] use_cache: whether to use data cache prefix_override: if you want to override default prompt for this task specify this via a string. + pad_to_max_length: If true, pad to the maximum length. """ super().__init__(file_name, task_name, tokenizer, max_seq_length, use_cache, compute_features=False) self.max_seq_length = max_seq_length self.max_seq_length_decoder = max_seq_length_decoder + self.pad_to_max_length = pad_to_max_length self.processor = processors[self.task_name]() self.prefix_override = prefix_override self.features = self.convert_examples_to_features() @@ -412,9 +415,16 @@ def collate_fn(self, batch): dec_input = [item['text_dec'] for item in batch] labels = [item['labels'] for item in batch] - max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0 max_enc_query_length = max([len(item) for item in enc_query]) if enc_query else 0 + max_dec_input_length = max([len(item) for item in dec_input]) if dec_input else 0 max_label_length = max([len(item) for item in labels]) if labels else 0 + if self.pad_to_max_length: + assert max_enc_query_length <= self.max_seq_length + assert max_dec_input_length <= self.max_seq_length_decoder + assert max_label_length <= self.max_seq_length_decoder + max_enc_query_length = self.max_seq_length + max_dec_input_length = self.max_seq_length_decoder + max_label_length = self.max_seq_length_decoder loss_mask = [([1] * (len(item))) + ([0] * (max_label_length - len(item))) for item in labels] enc_query = [item + [self.tokenizer.pad_id] * (max_enc_query_length - len(item)) for item in enc_query] @@ -488,10 +498,18 @@ def __init__( use_cache: bool = True, prefix_override: str = None, lang_list: List[str] = None, + pad_to_max_length: bool = True, ): self.lang_list = set(lang_list) super().__init__( - file_name, task_name, tokenizer, max_seq_length, max_seq_length_decoder, use_cache, prefix_override + file_name, + task_name, + tokenizer, + max_seq_length, + max_seq_length_decoder, + use_cache, + prefix_override, + pad_to_max_length, ) if len(lang_list) <= 0 or lang_list is None: raise ValueError(f"Found an empty or None lang_list for {self.task_name}") diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py index e55e8ce52faf..d1f0718a6abd 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py @@ -54,13 +54,13 @@ from nemo.utils.get_rank import is_global_rank_zero try: - from apex.transformer import parallel_state + from megatron.core import parallel_state - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False DSET_TYPE_BERT = 'standard_bert' diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py index 108cb01dceef..bf34626d4ea4 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py @@ -32,13 +32,13 @@ from nemo.utils import logging try: - from apex.transformer import parallel_state + from megatron.core import parallel_state - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False def build_dataset(cfg, trainer, data_prefix, data_impl, num_samples, seq_length, seed, skip_warmup, tokenizer, name): @@ -303,9 +303,9 @@ def __init__( seed, drop_last=True, ): - if not HAVE_APEX: + if not HAVE_MEGATRON_CORE: raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) super().__init__() @@ -432,9 +432,9 @@ class MockGPTDataset(Dataset): def __init__( self, cfg, tokenizer, name, num_samples, seq_length, seed, ): - if not HAVE_APEX: + if not HAVE_MEGATRON_CORE: raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + "Megatron core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) super().__init__() diff --git a/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py index e0f23184e5d8..c9791bc0147d 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/megatron_batch_samplers.py @@ -69,9 +69,9 @@ def __init__( micro_batch_size: The size of each micro batch. global_batch_size: The size of global batch. data_parallel_rank: The value you can obtain via - `parallel_state.get_data_parallel_rank()` of apex.transformer. + `parallel_state.get_data_parallel_rank()` of megatron.core. data_parallel_size: The value you can obtain via - `parallel_state.get_data_parallel_world_size()` of apex.transformer. + `parallel_state.get_data_parallel_world_size()` of megatron.core. """ # Sanity checks. if total_samples <= 0: diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py index 1dcf5609c85f..f0a501d7cc13 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py @@ -37,11 +37,13 @@ from nemo.utils import logging try: - from apex.transformer import parallel_state + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False __all__ = [ "RETRODataset", @@ -76,9 +78,9 @@ def __init__( knn_index: KNNIndex, retrieval_index: MMapRetrievalIndexedDataset, ): - if not HAVE_APEX: + if not HAVE_MEGATRON_CORE: raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) super().__init__() diff --git a/nemo/collections/nlp/models/language_modeling/megatron/__init__.py b/nemo/collections/nlp/models/language_modeling/megatron/__init__.py index 0ede5e4a9cc9..3afb1e3fae48 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/__init__.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/__init__.py @@ -17,8 +17,8 @@ try: from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False # from nemo.collections.nlp.models.language_modeling.megatron.t5_model import T5Model diff --git a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py index 009eb2629dd2..464d69c72043 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/bert_model.py @@ -31,16 +31,26 @@ ) try: - from apex.transformer import parallel_state, tensor_parallel from apex.transformer.enums import AttnMaskType from apex.transformer.tensor_parallel.layers import set_tensor_model_parallel_attributes HAVE_APEX = True except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + # fake missing classes with None attributes AttnMaskType = ApexGuardDefaults() +try: + from megatron.core import parallel_state, tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + def bert_extended_attention_mask(attention_mask): # We create a 3D attention mask from a 2D tensor mask. diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index 7c61c4078eb9..d6af1960eae9 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -26,15 +26,26 @@ ) try: - from apex.transformer import parallel_state, tensor_parallel from apex.transformer.enums import AttnMaskType HAVE_APEX = True + except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + # fake missing classes with None attributes AttnMaskType = ApexGuardDefaults() + HAVE_APEX = False + +try: + from megatron.core import parallel_state, tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + def post_language_model_processing( lm_output, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 97b500ce30bf..3b223a5744af 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -36,13 +36,24 @@ from nemo.utils.get_rank import is_global_rank_zero try: - from apex.transformer import parallel_state from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = ["MegatronBaseModel"] @@ -63,13 +74,15 @@ class MegatronBaseModel(NLPModel): """ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): - # FIXME: switch to self._cfg - if not HAVE_APEX: + + if not HAVE_MEGATRON_CORE: raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) + if trainer is None: raise ValueError(f"Trainer cannot be None for Megatron-based models. Please provide a PTL trainer object.") + # this prevents base constructor from initializing tokenizer self.tokenizer = None @@ -298,8 +311,8 @@ def sync_overlap_parameters(self, params=None): if self.with_distributed_adam: self._optimizer._try_start_bucket_param_sync(params) - def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[int] = 0) -> None: - super().on_train_batch_end(outputs, batch, batch_idx) + def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unused: Optional[int] = 0) -> None: + super().on_train_batch_end(outputs, dataloader_iter, batch_idx) # TODO: Replace with newer override for scheduler.step() instead of # search for plugins for fp16 GradScalar diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index 24d8998c3fbe..9e79cb4a41e7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import re from collections import OrderedDict +from typing import Any, Optional import torch from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.trainer.trainer import Trainer from torch import Tensor @@ -32,7 +35,26 @@ VirtualPromptStyle, ) from nemo.collections.nlp.modules.common.transformer.text_generation import TextGeneration -from nemo.utils import logging +from nemo.collections.nlp.parts.nlp_overrides import GradScaler +from nemo.utils import AppState, logging + +try: + from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = ['MegatronBasePromptLearningModel'] @@ -355,6 +377,29 @@ def setup_test_data(self, test_data_config=None): pin_memory=True, ) + def _reconfigure_and_process_inference_batch(self, global_batch_size_per_gpu, gbs): + # This should happen only on the last batch of the dataset. + if global_batch_size_per_gpu != gbs // parallel_state.get_data_parallel_world_size(): + # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def _reconfigure_batch_sizes(self, gbs: int, mbs: int): + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=gbs, + micro_batch_size=mbs, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + def set_inference_config(self, inference_config): self._inference_config = inference_config diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index 608ab1017b9c..12d121ca2bcd 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -28,6 +28,7 @@ ) from nemo.collections.nlp.models.language_modeling.megatron.bert_model import BertModel from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import ( average_losses_across_data_parallel_group, @@ -40,20 +41,24 @@ from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state - from apex.transformer.pipeline_parallel.schedules.common import build_model - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, - ) - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, - ) + from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False +try: + from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + class MegatronBertModel(MegatronBaseModel): """ @@ -62,9 +67,9 @@ class MegatronBertModel(MegatronBaseModel): """ def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_APEX: + if not HAVE_MEGATRON_CORE: raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False) self.cfg = cfg @@ -176,17 +181,6 @@ def _validate_trainer(self): f'Gradient accumulation is done within training_step. trainer.accumulate_grad_batches must equal 1' ) - def _get_fwd_bwd_function(self): - fwd_bwd_function = None - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: - fwd_bwd_function = _forward_backward_pipelining_with_interleaving - else: - fwd_bwd_function = forward_backward_pipelining_without_interleaving - else: - fwd_bwd_function = forward_backward_no_pipelining - return fwd_bwd_function - def get_forward_output_and_loss_func(self): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): if parallel_state.get_pipeline_model_parallel_world_size() == 1: @@ -303,45 +297,20 @@ def training_step(self, dataloader_iter, batch_idx): tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] - # handle asynchronous grad reduction - custom_sync_context_handler = None - custom_grad_sync_func = None - custom_param_sync_func = None - if self.with_distributed_adam: - if self.megatron_amp_o2: - # copy grads to main grad - custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True) - else: - # keep grad tensors around - custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False) - custom_grad_sync_func = self.reduce_overlap_gradients - custom_param_sync_func = self.sync_overlap_parameters - else: - if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False): - custom_sync_context_handler = self._optimizer.no_sync - else: - # TODO: enable async grad all reduce for O1/autocast mixed precision training - custom_sync_context_handler = None - # run forward and backwards passes for an entire global batch # we do this inside training_step to support pipeline parallelism - fwd_bwd_function = self._get_fwd_bwd_function() + fwd_bwd_function = get_forward_backward_func() losses_reduced_per_micro_batch = fwd_bwd_function( forward_step_func=self.get_forward_output_and_loss_func(), - batch=dataloader_iter, - model=self.model, + data_iterator=dataloader_iter, + model=[self.model], + num_microbatches=get_num_microbatches(), forward_only=False, tensor_shape=tensor_shape, dtype=self.autocast_dtype, grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - custom_sync_context_handler=custom_sync_context_handler, - custom_grad_sync_func=custom_grad_sync_func, - custom_param_sync_func=custom_param_sync_func, - sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), - num_micro_batches_with_partial_activation_checkpoints=self.cfg.get( - 'num_micro_batches_with_partial_activation_checkpoints', None - ), + sequence_parallel=self.cfg.get('sequence_parallel', False), ) if losses_reduced_per_micro_batch: @@ -431,16 +400,17 @@ def allreduce_first_last_embeddings(self): def validation_step(self, dataloader_iter, batch_idx): tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] - fwd_bwd_function = self._get_fwd_bwd_function() + fwd_bwd_function = get_forward_backward_func() losses_reduced_per_micro_batch = fwd_bwd_function( forward_step_func=self.get_forward_output_and_loss_func(), - batch=dataloader_iter, - model=self.model, + data_iterator=dataloader_iter, + model=[self.model], + num_microbatches=get_num_microbatches(), forward_only=True, tensor_shape=tensor_shape, dtype=self.autocast_dtype, - sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), + sequence_parallel=self.cfg.get('sequence_parallel', False), ) if losses_reduced_per_micro_batch: @@ -549,7 +519,7 @@ def build_train_valid_test_datasets(self): def backward(self, *args, **kwargs): """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. No need to call it here. """ return @@ -564,7 +534,7 @@ def _append_sequence_parallel_module_grads(self, module, grads): """ Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): - sequence_parallel_param = getattr(param, 'sequence_parallel_enabled', False) + sequence_parallel_param = getattr(param, 'sequence_parallel', False) if sequence_parallel_param: if self.megatron_amp_o2: grad = param.main_grad @@ -782,13 +752,13 @@ def configure_optimizers(self): # Disable overlapped grad sync for layer norm grads when # sequence parallelism is enabled for param in self.parameters(): - if getattr(param, 'sequence_parallel_enabled', False): + if getattr(param, 'sequence_parallel', False): param._disable_greedy_grad_copy = not self.megatron_amp_o2 param._disable_overlap_grad_sync = True # sequence parallelism is enabled for param in self.parameters(): - if getattr(param, 'sequence_parallel_enabled', False): + if getattr(param, 'sequence_parallel', False): param._disable_greedy_grad_copy = not self.megatron_amp_o2 param._disable_overlap_grad_sync = True @@ -869,41 +839,3 @@ def on_load_checkpoint(self, checkpoint) -> None: parallel_state.set_virtual_pipeline_model_parallel_rank(i) self.model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True) parallel_state.set_virtual_pipeline_model_parallel_rank(0) - - def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unused: Optional[int] = 0) -> None: - super().on_train_batch_end(outputs, dataloader_iter, batch_idx) - - # TODO: Replace with newer override for scheduler.step() instead of - # search for plugins for fp16 GradScalar - if self.trainer.precision_plugin is not None and isinstance( - self.trainer.precision_plugin, NativeMixedPrecisionPlugin - ): - precision_plugin = self.trainer.precision_plugin - - if ( - hasattr(precision_plugin, 'scaler') - and precision_plugin.scaler is not None - and isinstance(precision_plugin.scaler, GradScaler) - ): - grad_scaler = precision_plugin.scaler - - # If the grad scaler skipped its optimizer step due to infs/nans, - # decrement the step of all schedulers. - if grad_scaler.optimizer_update_skipped is not None and grad_scaler.optimizer_update_skipped is True: - scheduler_cfgs = self.trainer.lr_scheduler_configs - - if not scheduler_cfgs or not self.trainer.lightning_module.automatic_optimization: - return - - for scheduler_cfg in scheduler_cfgs: - # Decrement the counter by 2, then perform a scheduler.step() to perform a no-up - # as well as update the optimizer lr in all param groups - scheduler_cfg.scheduler.last_epoch -= 2 - scheduler_cfg.scheduler.step() - - # Removing the line below because it messes up train_valid_test_num_samples calculation. - # self.trainer.fit_loop.max_steps = self.trainer.fit_loop.max_steps + 1 - - # Reset the optimizer update skipped to `None` - this is to prevent scheduler no-ops during - # accumulated gradient updates. - grad_scaler.optimizer_update_skipped = None diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py index 8ce8b22f02c1..452819e1d5c4 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py @@ -11,7 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import json +from typing import Dict, List import torch from omegaconf import DictConfig, ListConfig @@ -22,17 +24,30 @@ from nemo.collections.common.metrics.classification_accuracy import ExactStringPerCategoryMatchMetric from nemo.collections.nlp.data.common.sequence_to_sequence_dataset import SequenceToSequenceDataset from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model, T5Sentinel -from nemo.collections.nlp.parts.nlp_overrides import GlobalBatchDataFetcher +from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator, get_num_microbatches + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator, + get_micro_batch_size, + get_num_microbatches, + ) HAVE_APEX = True except (ImportError, ModuleNotFoundError): HAVE_APEX = False +try: + from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = ['MegatronT5FinetuneModel'] @@ -131,11 +146,6 @@ def setup(self, stage=None): if hasattr(self, '_train_ds'): self.setup_training_data() - def _process_global_batch(self, global_batch): - """Optionally processes a global batch.""" - # TODO: maybe remove this now that we've refactored data batch sizes. - return global_batch - def on_validation_epoch_start(self): app_state = AppState() _reconfigure_microbatch_calculator( @@ -192,25 +202,6 @@ def on_train_epoch_start(self) -> None: self.on_validation_epoch_end() return super().on_train_epoch_start() - def training_step(self, batch, batch_idx): - global_batch_size_per_gpu = batch['text_enc'].size(0) - # This should happen only on the last batch of the dataset. - if ( - global_batch_size_per_gpu - != self.cfg.data.train_ds.global_batch_size // parallel_state.get_data_parallel_world_size() - ): - # NOTE: This should never really be called since `drop_last=True` is required for training datasets. - app_state = AppState() - _reconfigure_microbatch_calculator( - rank=app_state.global_rank, - rampup_batch_size=None, - global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), - micro_batch_size=global_batch_size_per_gpu // get_num_microbatches(), - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - batch = self._process_global_batch(batch) - return super().training_step(batch, batch_idx) - def cast_for_metric(self, pred, label, metric_name, class_labels=None, labels_are_strings=False): if metric_name == 'exact_string_match': return pred, label @@ -280,36 +271,78 @@ def _reconfigure_and_process_inference_batch(self, batch, ds_config): data_parallel_size=parallel_state.get_data_parallel_world_size(), ) - processed_batch = self._process_global_batch(batch) - return processed_batch + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + """ + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + """ + # Get seq length of batch + batch = next(dataloader_iter) + if isinstance(batch, dict): + # convert to list if not already converted. + batch = self._process_batch(batch) + + _, seq_length = batch[0].shape + _, dec_seq_length = batch[1].shape + tensor_shape = [seq_length, get_micro_batch_size(), self.hidden_size] + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=data_iter, + model=[self.enc_dec_model], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + tensor_shape=tensor_shape, + decoder_seq_length=dec_seq_length, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + sequence_parallel=self.cfg.get('sequence_parallel', False), + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # we're not on the last pipeline stage so no losses + loss_mean = torch.tensor(0.0).cuda() - def inference_step(self, batch, batch_idx, mode, dataloader_idx=0): - # Regular finetuning datasets will return a list of dicts for each microbatch. But T0 datasets will return a single dict for the global batch. + return loss_mean + + def inference_step(self, dataloader_iter, batch_idx: int, mode: str, dataloader_idx=0): + # Regular finetuning datasets will return a list of dicts for each microbatch. + # But T0 datasets will return a single dict for the global batch. + batch = next(dataloader_iter) batch_has_lang_information = isinstance(batch, list) and len(batch[0]) == 7 data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds - processed_batch = self._reconfigure_and_process_inference_batch(batch, data_cfg) + self._reconfigure_and_process_inference_batch(batch, data_cfg) - # Call parent validation step to get the loss. - # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI, this will be ignored in the parent class. - loss = super().validation_step(processed_batch, batch_idx) + # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI, + # this will be ignored. + loss = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=True) predicted_token_ids, _ = self.decode( - tokens_enc=processed_batch['text_enc'], - enc_mask=processed_batch['enc_mask'], + tokens_enc=batch['text_enc'], + enc_mask=batch['enc_mask'], num_tokens_to_generate=30, bos_id=self.tokenizer.pad_id if data_cfg.replace_bos_with_pad else self.tokenizer.bos_id, ) # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. preds_text = MegatronT5FinetuneModel.ids_to_text(predicted_token_ids, self.tokenizer) - labels_text = MegatronT5FinetuneModel.ids_to_text(processed_batch['labels'], self.tokenizer) - input_text = MegatronT5FinetuneModel.ids_to_text(processed_batch['text_enc'], self.tokenizer) + labels_text = MegatronT5FinetuneModel.ids_to_text(batch['labels'], self.tokenizer) + input_text = MegatronT5FinetuneModel.ids_to_text(batch['text_enc'], self.tokenizer) if not batch_has_lang_information: categories = [None] * len(preds_text) else: - categories = processed_batch['lang'] + categories = batch['lang'] metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx] assert len(categories) == len(preds_text) == len(labels_text) @@ -388,7 +421,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): loss_log_key = self._determine_log_key(data_cfg, dataloader_idx, "loss", mode) # Determine the key used to log the eval metric based on the user provided name of the dataset or the dataloader index. metric_log_key = self._determine_log_key(data_cfg, dataloader_idx, metric_name, mode) - self.log(loss_log_key, loss) + self.log(loss_log_key, loss, batch_size=1) metric_object = ( self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx] ) @@ -398,17 +431,17 @@ def inference_epoch_end(self, outputs, mode, data_cfg): # GLUE case: if len(metric) == 1 and 'acc' in metric: metric = metric['acc'] - self.log(metric_log_key, metric) + self.log(metric_log_key, metric, batch_size=1) logging.info(f"{mode} {metric_name}: {metric}") # XNLI case where the metric dictionary contains the language and the computed metric as values. else: for k, v in metric.items(): if k != 'acc' and 'total' not in k: - self.log(metric_log_key + f'_{k}', v) + self.log(metric_log_key + f'_{k}', v, batch_size=1) logging.info(f"{mode} {metric_name} lang {k} : {v}") metric = metric['acc'] else: - self.log(metric_log_key, metric) + self.log(metric_log_key, metric, batch_size=1) logging.info(f"{metric_log_key}: {metric}") metric_object.reset() @@ -481,11 +514,11 @@ def inference_epoch_end(self, outputs, mode, data_cfg): averaged_metric = 0.0 if monitor_mode == 'max' else 1e5 if mode == 'validation': - self.log("validation_loss", averaged_loss) - self.log(f"validation_{self.val_metric_name}", averaged_metric) + self.log("validation_loss", averaged_loss, batch_size=1) + self.log(f"validation_{self.val_metric_name}", averaged_metric, batch_size=1) elif mode == 'test': - self.log("test_loss", averaged_loss) - self.log(f"test_{self.test_metric_name}", averaged_metric) + self.log("test_loss", averaged_loss, batch_size=1) + self.log(f"test_{self.test_metric_name}", averaged_metric, batch_size=1) return averaged_loss, averaged_metric @@ -495,14 +528,14 @@ def write_predictions_to_file(self, outputs, output_file_path_prefix): for i, p, l in zip(outputs['inputs'], outputs['preds'], outputs['labels']): f_json.write(json.dumps({'input': i, 'pred': p, 'label': l}) + '\n') - def validation_step(self, batch, batch_idx, dataloader_idx=0): - return self.inference_step(batch, batch_idx, 'validation', dataloader_idx) + def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + return self.inference_step(dataloader_iter, batch_idx, 'validation', dataloader_idx) def validation_epoch_end(self, outputs): _ = self.inference_epoch_end(outputs, 'validation', self.cfg.data.validation_ds) - def test_step(self, batch, batch_idx, dataloader_idx=0): - return self.inference_step(batch, batch_idx, 'test', dataloader_idx) + def test_step(self, dataloader_iter, batch_idx, dataloader_idx=0): + return self.inference_step(dataloader_iter, batch_idx, 'test', dataloader_idx) def test_epoch_end(self, outputs): _ = self.inference_epoch_end(outputs, 'test', self.cfg.data.test_ds) @@ -536,6 +569,11 @@ def build_data_loader( ) def setup_training_data(self): + if not self.cfg.data.train_ds.drop_last: + raise AttributeError( + "`drop_last` is required for the training dataset to ensure each batch is the same micro-batch size." + "To set this, set the variable `data.train_ds.drop_last=True` in the config." + ) self._train_dl = self.build_data_loader( self._train_ds, global_batch_size=self.cfg.data.train_ds.global_batch_size, @@ -550,7 +588,7 @@ def setup_eval_data(self, datasets, data_cfg): for dataset in datasets: eval_dl = self.build_data_loader( dataset, - global_batch_size=data_cfg.global_batch_size, + global_batch_size=self.cfg.data.train_ds.global_batch_size, shuffle=data_cfg.shuffle, num_workers=data_cfg.num_workers, pin_memory=data_cfg.pin_memory, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py b/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py index 79ed7a1c5e8a..5cc0f7ea3a32 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py @@ -22,11 +22,13 @@ from nemo.utils import logging try: - from apex.transformer import parallel_state + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False __all__ = ['MegatronT5GLUEModel'] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py index a9714b428b8e..aa7cd4652b0a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py @@ -17,6 +17,7 @@ # Adapted by: @adithyare +import itertools import os import torch @@ -194,7 +195,8 @@ def setup_optimizer_param_groups(self): logging.info(f'Optimizer groups set:\n{self.frozen_model.summarize()}') def get_forward_output_and_loss_func(self): - def fwd_output_and_loss_func(batch, model): + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) batch = [x.cuda(non_blocking=True) for x in batch] input_ids, labels, loss_mask, position_ids, attention_mask, taskname_ids = batch output_tensor = model(input_ids, position_ids, attention_mask, taskname_ids, labels, inference=False) @@ -208,10 +210,11 @@ def loss_func(output_tensor): return fwd_output_and_loss_func - def training_step(self, batch, batch_idx): - # we zero grads here because we also call backward in the apex fwd/bwd functions + def training_step(self, dataloader_iter, batch_idx): + # we zero grads here because we also call backward in the megatron-core fwd/bwd functions self._optimizer.zero_grad() - loss_mean = self.fwd_bwd_step(batch, batch_idx, forward_only=False) + batch = next(dataloader_iter) + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=False) self.allreduce_gradients() ## logging @@ -222,12 +225,12 @@ def training_step(self, batch, batch_idx): if self.cfg.precision == 16: loss_scale = self.trainer.precision_plugin.scaler._scale if loss_scale is not None: - self.log('loss_scale', loss_scale) + self.log('loss_scale', loss_scale, batch_size=1) - self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) lr = self._optimizer.param_groups[0]['lr'] - self.log('lr', lr, rank_zero_only=True) - self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True) + self.log('lr', lr, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) # Need to make sure the frozen model param learning rate stays 0.0 # so forceing lr to be 0.0 for gpt layers before param update diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 19dc724c4043..f1e231e96c68 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -29,6 +29,7 @@ from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import ( average_losses_across_data_parallel_group, @@ -48,26 +49,30 @@ SamplingParam, TextGeneration, ) -from nemo.collections.nlp.parts.nlp_overrides import GradScaler, build_model_cpu +from nemo.collections.nlp.parts.nlp_overrides import GradScaler from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging try: - from apex.transformer import parallel_state - from apex.transformer.pipeline_parallel.schedules.common import build_model - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, - ) - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, - ) + from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False +try: + from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + try: import transformer_engine @@ -87,6 +92,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): raise ImportError( "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) + + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) # this prevents base constructor from initializing tokenizer self.tokenizer = None super().__init__(cfg, trainer=trainer, no_lm_init=True) @@ -100,9 +110,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # build_model returns a list of modules which are used for interleaved pipeline parallelism if isinstance(self.trainer.accelerator, CPUAccelerator): - self.model = build_model_cpu( + self.model = build_model( model_provider_func=self.model_provider_func, wrap_with_ddp=False, + on_cpu=True, virtual_pipeline_model_parallel_size=self.cfg.get('virtual_pipeline_model_parallel_size', None), ) else: @@ -268,7 +279,7 @@ def configure_optimizers(self): # Disable overlapped grad sync for layer norm grads when # sequence parallelism is enabled for param in self.parameters(): - if getattr(param, 'sequence_parallel_enabled', False): + if getattr(param, 'sequence_parallel', False): param._disable_greedy_grad_copy = not self.megatron_amp_o2 param._disable_overlap_grad_sync = True @@ -310,17 +321,6 @@ def forward(self, tokens, text_position_ids, attention_mask, labels): output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels) return output_tensor - def _get_fwd_bwd_function(self): - fwd_bwd_function = None - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: - fwd_bwd_function = _forward_backward_pipelining_with_interleaving - else: - fwd_bwd_function = forward_backward_pipelining_without_interleaving - else: - fwd_bwd_function = forward_backward_no_pipelining - return fwd_bwd_function - def training_step(self, dataloader_iter, batch_idx): """ We pass the dataloader iterator function to the micro-batch scheduler. @@ -328,7 +328,7 @@ def training_step(self, dataloader_iter, batch_idx): in the micro-batch fwd function. """ - # we zero grads here because we also call backward in the apex fwd/bwd functions + # we zero grads here because we also call backward in the megatron-core fwd/bwd functions self._optimizer.zero_grad() if self.with_distributed_adam: @@ -351,48 +351,21 @@ def training_step(self, dataloader_iter, batch_idx): tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] - # handle asynchronous grad reduction - custom_sync_context_handler = None - custom_grad_sync_func = None - custom_param_sync_func = None - if self.with_distributed_adam: - if self.megatron_amp_o2: - # copy grads to main grad - custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True) - else: - # keep grad tensors around - custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False) - custom_grad_sync_func = self.reduce_overlap_gradients - custom_param_sync_func = self.sync_overlap_parameters - else: - if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False): - custom_sync_context_handler = self._optimizer.no_sync - else: - # TODO: enable async grad all reduce for O1/autocast mixed precision training - custom_sync_context_handler = None - # run forward and backwards passes for an entire global batch # we do this inside training_step to support pipeline parallelism - fwd_bwd_function = self._get_fwd_bwd_function() + fwd_bwd_function = get_forward_backward_func() + # TODO @akhattar: remove sync related stuff from config, add num_micro_batches_with_partial_activation_checkpoints when ready losses_reduced_per_micro_batch = fwd_bwd_function( forward_step_func=self.get_forward_output_and_loss_func(), - batch=dataloader_iter, - model=self.model, + data_iterator=dataloader_iter, + model=[self.model], + num_microbatches=get_num_microbatches(), forward_only=False, tensor_shape=tensor_shape, dtype=self.autocast_dtype, grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - custom_sync_context_handler=custom_sync_context_handler, - custom_grad_sync_func=custom_grad_sync_func, - custom_param_sync_func=custom_param_sync_func, - sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - num_micro_batches_with_partial_activation_checkpoints=self.cfg.get( - 'num_micro_batches_with_partial_activation_checkpoints', None - ), - overlap_p2p_comm=self.cfg.get('overlap_p2p_comm', False), - batch_p2p_comm=self.cfg.get('batch_p2p_comm', True), + sequence_parallel=self.cfg.get('sequence_parallel', False), ) # only the last stages of the pipeline return losses @@ -459,7 +432,7 @@ def training_step(self, dataloader_iter, batch_idx): def backward(self, *args, **kwargs): """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. No need to call it here. """ return @@ -474,10 +447,7 @@ def _append_sequence_parallel_module_grads(self, module, grads): """ Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): - if getattr(self, 'transformer_engine', False): - sequence_parallel_param = getattr(param, 'sequence_parallel', False) - else: - sequence_parallel_param = getattr(param, 'sequence_parallel_enabled', False) + sequence_parallel_param = getattr(param, 'sequence_parallel', False) if sequence_parallel_param: if self.megatron_amp_o2: grad = param.main_grad @@ -644,25 +614,24 @@ def validation_step(self, dataloader_iter, batch_idx): Our dataloaders produce a micro-batch and then we fetch a number of microbatches depending on the global batch size and model parallel size from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] # run forward passes for an entire global batch # we do this inside validation_step to support pipeline parallelism - fwd_bwd_function = self._get_fwd_bwd_function() + fwd_bwd_function = get_forward_backward_func() losses_reduced_per_micro_batch = fwd_bwd_function( forward_step_func=self.get_forward_output_and_loss_func(validation_step=True), - batch=dataloader_iter, - model=self.model, + data_iterator=dataloader_iter, + model=[self.model], + num_microbatches=get_num_microbatches(), forward_only=True, tensor_shape=tensor_shape, dtype=self.autocast_dtype, - sequence_parallel_enabled=self.cfg.get('sequence_parallel', False), - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - batch_p2p_comm=self.cfg.get('batch_p2p_comm', True), + sequence_parallel=self.cfg.get('sequence_parallel', False), ) # only the last stage of the pipeline returns losses @@ -1033,41 +1002,3 @@ def parameters(self): return itertools.chain.from_iterable(module.parameters() for module in self.model) else: return self.model.parameters() - - def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unused: Optional[int] = 0) -> None: - super().on_train_batch_end(outputs, dataloader_iter, batch_idx) - - # TODO: Replace with newer override for scheduler.step() instead of - # search for plugins for fp16 GradScalar - if self.trainer.precision_plugin is not None and isinstance( - self.trainer.precision_plugin, NativeMixedPrecisionPlugin - ): - precision_plugin = self.trainer.precision_plugin - - if ( - hasattr(precision_plugin, 'scaler') - and precision_plugin.scaler is not None - and isinstance(precision_plugin.scaler, GradScaler) - ): - grad_scaler = precision_plugin.scaler - - # If the grad scaler skipped its optimizer step due to infs/nans, - # decrement the step of all schedulers. - if grad_scaler.optimizer_update_skipped is not None and grad_scaler.optimizer_update_skipped is True: - scheduler_cfgs = self.trainer.lr_scheduler_configs - - if not scheduler_cfgs or not self.trainer.lightning_module.automatic_optimization: - return - - for scheduler_cfg in scheduler_cfgs: - # Decrement the counter by 2, then perform a scheduler.step() to perform a no-up - # as well as update the optimizer lr in all param groups - scheduler_cfg.scheduler.last_epoch -= 2 - scheduler_cfg.scheduler.step() - - # Removing the line below because it messes up train_valid_test_num_samples calculation. - # self.trainer.fit_loop.max_steps = self.trainer.fit_loop.max_steps + 1 - - # Reset the optimizer update skipped to `None` - this is to prevent scheduler no-ops during - # accumulated gradient updates. - grad_scaler.optimizer_update_skipped = None diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index bb2d07d54650..acf96688b33a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -14,56 +14,55 @@ import itertools import os -import re from functools import partial from typing import Any, List, Optional, Union import torch -from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict from pytorch_lightning.trainer.trainer import Trainer -from torch import Tensor from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset from nemo.collections.nlp.metrics.prompt_learning_metrics import AccuracyScore, BLEUScore, ROUGEScores -from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel from nemo.collections.nlp.models.language_modeling.megatron_base_prompt_learning_model import ( MegatronBasePromptLearningModel, ) from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel -from nemo.collections.nlp.modules.common import ( - PromptEncoder, - PromptEncoderType, - VirtualPromptPlaceholderToken, - VirtualPromptSource, - VirtualPromptStyle, +from nemo.collections.nlp.modules.common import VirtualPromptPlaceholderToken, VirtualPromptSource, VirtualPromptStyle +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, ) -from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group from nemo.collections.nlp.modules.common.text_generation_utils import ( get_default_length_params, get_default_sampling_params, megatron_gpt_generate, ) -from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam, TextGeneration -from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import GradScaler, NLPSaveRestoreConnector from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state, tensor_parallel - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, - ) - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator, get_micro_batch_size + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches HAVE_APEX = True except (ImportError, ModuleNotFoundError): HAVE_APEX = False +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.enums import ModelType + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = ['MegatronGPTPromptLearningModel'] @@ -149,6 +148,7 @@ def init_model(self, cfg: DictConfig, trainer: Trainer): ) # TODO: for backward compatibility (@adithyare) in general these tasks lists should be depricated self.virtual_prompt_style = VirtualPromptStyle(cfg.virtual_prompt_style) + self.model_type = ModelType.encoder_or_decoder if self.pipeline_parallel: assert ( @@ -286,38 +286,30 @@ def forward( return output - def fwd_bwd_step(self, batch, batch_idx, forward_only): + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): """ - Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Dataloader produces a global batch which is turned into an iterator of microbatches. + The iterator of microbatches is then piped through the pipeline using Core's fwd/bwd functions. """ # Get seq length of batch + batch = next(dataloader_iter) _, seq_length = batch[0].shape tensor_shape = [seq_length, get_micro_batch_size(), self.hidden_size] - - if self.pipeline_parallel: - losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch, - model=self, - forward_only=forward_only, - tensor_shape=tensor_shape, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - sequence_parallel_enabled=self.cfg.get("sequence_parallel", False), - sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False), - ) - else: - losses_reduced_per_micro_batch = forward_backward_no_pipelining( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch, - model=self, - forward_only=forward_only, - tensor_shape=tensor_shape, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False), - ) + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=data_iter, + model=[self], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + tensor_shape=tensor_shape, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + sequence_parallel=self.cfg.get('sequence_parallel', False), + ) # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: @@ -331,10 +323,11 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only): return loss_mean - def training_step(self, batch, batch_idx): - # we zero grads here because we also call backward in the apex fwd/bwd functions + def training_step(self, dataloader_iter, batch_idx): + # we zero grads here because we also call backward in the megatron-core fwd/bwd functions self._optimizer.zero_grad() - loss_mean = self.fwd_bwd_step(batch, batch_idx, forward_only=False) + batch = next(dataloader_iter) + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=False) self.allreduce_gradients() ## logging @@ -345,17 +338,17 @@ def training_step(self, batch, batch_idx): if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"): loss_scale = self.trainer.precision_plugin.scaler._scale if loss_scale is not None: - self.log('loss_scale', loss_scale) + self.log('loss_scale', loss_scale, batch_size=1) - self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) lr = self._optimizer.param_groups[0]['lr'] - self.log('lr', lr, rank_zero_only=True) - self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True) + self.log('lr', lr, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) return loss_mean def backward(self, *args, **kwargs): """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. No need to call it here. """ return @@ -366,23 +359,11 @@ def optimizer_zero_grad(self, *args, **kwargs): """ return - def _reconfigure_and_process_inference_batch(self, global_batch_size_per_gpu, gbs): - # This should happen only on the last batch of the dataset. - if global_batch_size_per_gpu != gbs // parallel_state.get_data_parallel_world_size(): - # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. - app_state = AppState() - _reconfigure_microbatch_calculator( - rank=app_state.global_rank, - rampup_batch_size=None, - global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), - micro_batch_size=global_batch_size_per_gpu, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - - def validation_step(self, batch, batch_idx): + def validation_step(self, dataloader_iter, batch_idx): + batch = next(dataloader_iter) gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) self._reconfigure_and_process_inference_batch(batch[0].size(0), gbs) - loss_mean = self.fwd_bwd_step(batch, batch_idx, forward_only=True) + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=True) if loss_mean.item == 0.0: loss_mean = [] @@ -425,16 +406,6 @@ def validation_step(self, batch, batch_idx): } return {'loss': loss_mean} - def _reconfigure_batch_sizes(self, gbs: int, mbs: int): - app_state = AppState() - _reconfigure_microbatch_calculator( - rank=app_state.global_rank, - rampup_batch_size=None, - global_batch_size=gbs, - micro_batch_size=mbs, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - def on_train_epoch_start(self) -> None: gbs = self.cfg.global_batch_size mbs = self.cfg.micro_batch_size @@ -460,7 +431,7 @@ def validation_epoch_end(self, outputs): # we can only log on one rank if it is rank zero so we broadcast from last rank torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True) + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) logging.info(f'val_loss: {averaged_loss}') if self.cfg.get("report_validation_metric", False): @@ -495,14 +466,14 @@ def validation_epoch_end(self, outputs): val_metric = torch.tensor(0.0).cuda() metric_name = '' - self.log(f'val_{metric_name}', val_metric, prog_bar=True, rank_zero_only=True) + self.log(f'val_{metric_name}', val_metric, prog_bar=True, rank_zero_only=True, batch_size=1) gbs = self.cfg.global_batch_size mbs = self.cfg.micro_batch_size self._reconfigure_batch_sizes(gbs, mbs) - def test_step(self, batch, batch_idx): - return self.validation_step(batch, batch_idx) + def test_step(self, dataloader_iter, batch_idx): + return self.validation_step(dataloader_iter, batch_idx) def test_epoch_end(self, outputs): averaged_loss = average_losses_across_data_parallel_group(outputs) @@ -512,7 +483,7 @@ def setup_training_data(self, training_data_config=None): if self.cfg.data.get('train_ds', None): max_seq_length = self.frozen_model.cfg.encoder_seq_length if "max_seq_length" in self.cfg.data and self.cfg.data.max_seq_length: - max_seq_length = self.cfg.data.max_seq_length + max_seq_length = min(self.cfg.data.max_seq_length, max_seq_length) self._train_ds, self._train_dl = self.build_virtual_prompt_dataset( data=self.cfg.data.train_ds, batch_size=self.cfg.global_batch_size, @@ -533,7 +504,7 @@ def setup_validation_data(self, validation_data_config=None): if self.cfg.data.get('validation_ds', None): max_seq_length = self.frozen_model.cfg.encoder_seq_length if "max_seq_length" in self.cfg.data and self.cfg.data.max_seq_length: - max_seq_length = self.cfg.data.max_seq_length + max_seq_length = min(self.cfg.data.max_seq_length, max_seq_length) self._validation_ds, self._validation_dl = self.build_virtual_prompt_dataset( data=self.cfg.data.validation_ds, batch_size=self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size), @@ -649,7 +620,8 @@ def set_input_tensor(self, input_tensor): self.frozen_model.model.set_input_tensor(input_tensor) def get_forward_output_and_loss_func(self): - def fwd_output_and_loss_func(batch, model): + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) batch = [x.cuda(non_blocking=True) for x in batch] input_ids, labels, loss_mask, position_ids, attention_mask, taskname_ids = batch output_tensor = model(input_ids, position_ids, attention_mask, taskname_ids, labels, inference=False) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 37df8bbd4a57..d77256c0d813 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -30,13 +30,22 @@ from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator HAVE_APEX = True except (ImportError, ModuleNotFoundError): HAVE_APEX = False +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + __all__ = ['MegatronGPTSFTModel'] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index 35e0fe2259a8..dc89165fc2af 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -15,7 +15,7 @@ import copy import functools import inspect -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import torch from omegaconf import OmegaConf, open_dict @@ -23,17 +23,17 @@ from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.trainer.trainer import Trainer -from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( - MegatronPretrainingBatchSampler, - MegatronPretrainingRandomBatchSampler, +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( + MegatronPretrainingRandomSampler, + MegatronPretrainingSampler, ) from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import ( MegatronTokenLevelEncoderDecoderModule, ) from nemo.collections.nlp.modules.common.megatron.utils import ( - ApexGuardDefaults, average_losses_across_data_parallel_group, get_params_for_weight_decay_optimization, ) @@ -46,13 +46,6 @@ from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state, tensor_parallel - from apex.transformer.enums import ModelType - from apex.transformer.pipeline_parallel.schedules.common import build_model - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, - ) from apex.transformer.pipeline_parallel.utils import ( _reconfigure_microbatch_calculator, get_micro_batch_size, @@ -60,10 +53,21 @@ ) HAVE_APEX = True + except (ImportError, ModuleNotFoundError): - ModelType = ApexGuardDefaults() + HAVE_APEX = False +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.enums import ModelType + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False __all__ = ["MegatronLMEncoderDecoderModel"] @@ -91,13 +95,14 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # Make sure trainer.accumulate_grad_batches is 1. self._validate_trainer() - # TODO: Not sure how to use lists of modules with PTL. + # TODO: Currently does not support interleaved pipeline parallelism. # This means we can only use pipeline parallelism without the interleaved schedule. if isinstance(self.trainer.accelerator, CPUAccelerator): logging.warning("Using CPUAccelerator, model will be built on CPU.") - self.enc_dec_model = nlp_overrides.build_model_cpu( + self.enc_dec_model = build_model( model_provider_func=self.model_provider_func, wrap_with_ddp=False, + on_cpu=True, model_type=ModelType.encoder_and_decoder, )[0] else: @@ -303,74 +308,27 @@ def forward( return output_tensor - def training_step(self, batch, batch_idx): + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ - # we zero grads here because we also call backward in the apex fwd/bwd functions - self._optimizer.zero_grad() - # we prepare the micro batches for the apex fwd/bwd function - batch_for_pipeline = self.process_global_batch(batch) - encoder_seq_length = batch_for_pipeline[0].size(1) - decoder_seq_length = batch_for_pipeline[1].size(1) - tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] - - # handle asynchronous grad reduction - custom_sync_context_handler = None - custom_grad_sync_func = None - if self.with_distributed_adam: - if self.megatron_amp_o2: - # copy grads to main grad - custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True) - else: - # keep grad tensors around - custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False) - custom_grad_sync_func = self.reduce_overlap_gradients - else: - if ( - self.megatron_amp_o2 - and self.cfg.get('pipeline_model_parallel_size', 1) == 1 - and not self.cfg.get('sequence_parallel', False) - ): - custom_sync_context_handler = self._optimizer.no_sync - else: - # TODO: enable async grad all reduce with O1/autocast - # mixed precision training, with pipeline parallelism, - # or with sequence parallelism - custom_sync_context_handler = None - - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=False, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_seq_length, - dtype=self.autocast_dtype, - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - custom_sync_context_handler=custom_sync_context_handler, - custom_grad_sync_func=custom_grad_sync_func, - ) - else: - losses_reduced_per_micro_batch = forward_backward_no_pipelining( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=False, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_seq_length, - dtype=self.autocast_dtype, - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - custom_sync_context_handler=custom_sync_context_handler, - ) + # Get seq length of batch + tensor_shape = [self.max_encoder_seq_length, self.cfg.micro_batch_size, self.cfg.encoder.hidden_size] + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=dataloader_iter, + model=[self.enc_dec_model], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + tensor_shape=tensor_shape, + decoder_seq_length=self.max_decoder_seq_length, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + ) # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: @@ -379,7 +337,26 @@ def training_step(self, batch, batch_idx): loss_tensor = torch.concat(loss_tensors_list) loss_mean = loss_tensor.mean() else: - loss_mean = torch.tensor(0.0).cuda() + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean + + def training_step(self, dataloader_iter, batch_idx): + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + """ + # we zero grads here because we also call backward in the megatron fwd/bwd functions + self._optimizer.zero_grad() + + loss_mean = self.fwd_bwd_step(dataloader_iter, batch_idx, False) if self.with_distributed_adam: # synchronize asynchronous grad reductions @@ -408,25 +385,38 @@ def training_step(self, batch, batch_idx): if self.cfg.precision == 16: loss_scale = self.trainer.precision_plugin.scaler._scale if loss_scale is not None: - self.log('loss_scale', loss_scale) + self.log('loss_scale', loss_scale, batch_size=1) - self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) lr = self._optimizer.param_groups[0]['lr'] - self.log('lr', lr, rank_zero_only=True) - self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True) + self.log('lr', lr, rank_zero_only=True, batch_size=1) + self.log( + 'global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1, + ) # TODO: make sure compute_consumed_samples works for pipeline parallelism self.log( 'consumed_samples', self.compute_consumed_samples(self.trainer.global_step - self.init_global_step), prog_bar=True, rank_zero_only=True, + batch_size=1, ) - return loss_mean + @property + def max_decoder_seq_length(self) -> int: + seq_len = self._cfg.data.get('seq_length_dec', None) + if seq_len is None: + seq_len = self.cfg.seq_length + return seq_len + + @property + def max_encoder_seq_length(self) -> int: + return self.cfg.seq_length + def backward(self, *args, **kwargs): """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. No need to call it here. """ return @@ -548,8 +538,26 @@ def allreduce_word_and_position_embeddings(self): grad, group=parallel_state.get_decoder_relative_position_embedding_group() ) + def _process_batch(self, global_batch: Dict[str, torch.Tensor]) -> List[torch.Tensor]: + # If the decoder input starts with instead of , which is the case for huggingface T5 models, we don't want to mask the first token. + # For NeMo-Megatron, the sequence starts with , which is never masked so we can always set index 0 to be unmasked. + global_batch['dec_mask'][:, 0] = 1 + + return [ + global_batch["text_enc"], + global_batch["text_dec"], + global_batch["loss_mask"], + global_batch["labels"], + global_batch["enc_mask"], + global_batch["dec_mask"], + ] + def get_forward_output_and_loss_func(self): - def fwd_output_and_loss_func(batch, model): + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) + if isinstance(batch, dict): + # convert to list if not already converted. + batch = self._process_batch(batch) batch = [x.cuda(non_blocking=True) for x in batch] encoder_input_ids, decoder_input_ids, loss_mask, lm_labels, encoder_attn_mask, decoder_attn_mask = batch @@ -629,7 +637,8 @@ def _get_forward_output_only_func(self, arg_names, output_name, **kwargs): kwargs - shared arguments (non tensors) """ - def fwd_output_only_func(batch, model): + def fwd_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) batch = [x.cuda(non_blocking=True) for x in batch] # map batch and shared args into forward args @@ -643,109 +652,11 @@ def id_func(output_tensor): return fwd_output_only_func - def validation_step_logits(self, batch, batch_idx): - """ - return_values - if given, returns a dictionary with given keys and corresponding values - """ - batch_for_pipeline = self.process_global_batch(batch) - encoder_seq_length = batch_for_pipeline[0].size(1) - decoder_seq_length = batch_for_pipeline[1].size(1) - - tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] - - ( - encoder_input_ids, - decoder_input_ids, - loss_mask, - lm_labels, - encoder_attn_mask, - decoder_attn_mask, - ) = batch_for_pipeline - batch_for_pipeline = [encoder_input_ids, encoder_attn_mask, decoder_input_ids, decoder_attn_mask] - arg_names = ['enc_input_ids', 'enc_attn_mask', 'dec_input_ids', 'dec_attn_mask'] - - forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="logits") - - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - output_tensor = forward_backward_pipelining_without_interleaving( - forward_step_func=forward_step_func, - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=True, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_seq_length, - dtype=self.autocast_dtype, - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - ) - else: - output_tensor = forward_backward_no_pipelining( - forward_step_func=forward_step_func, - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=True, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_seq_length, - dtype=self.autocast_dtype, - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - ) - - if output_tensor: - # average loss across micro batches - logits_tensors_list = [o['logits'] for o in output_tensor] - logits_tensor = torch.concat(logits_tensors_list) - else: - # we're not on the last pipeline stage so no losses - logits_tensor = [] - - return logits_tensor - - def validation_step(self, batch, batch_idx, dataloader_idx=0): + def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): """ return_values - if given, returns a dictionary with given keys and corresponding values """ - batch_for_pipeline = self.process_global_batch(batch) - encoder_seq_length = batch_for_pipeline[0].size(1) - decoder_seq_length = batch_for_pipeline[1].size(1) - - tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] - - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=True, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_seq_length, - dtype=self.autocast_dtype, - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - ) - else: - losses_reduced_per_micro_batch = forward_backward_no_pipelining( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=True, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_seq_length, - dtype=self.autocast_dtype, - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - ) - - if losses_reduced_per_micro_batch: - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() - else: - # we're not on the last pipeline stage so no losses - loss_mean = [] - - return loss_mean + return self.fwd_bwd_step(dataloader_iter, batch_idx, True) def validation_epoch_end(self, outputs): if parallel_state.is_pipeline_last_stage(): @@ -756,13 +667,13 @@ def validation_epoch_end(self, outputs): # we can only log on one rank if it is rank zero so we broadcast from last rank torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True) - self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True) + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) return averaged_loss - def test_step(self, batch, batch_idx): - return self.validation_step(batch, batch_idx) + def test_step(self, dataloader_iter, batch_idx): + return self.validation_step(dataloader_iter, batch_idx) def test_epoch_end(self, outputs): if parallel_state.is_pipeline_last_stage(): @@ -773,7 +684,7 @@ def test_epoch_end(self, outputs): # we can only log on one rank if it is rank zero so we broadcast from last rank torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('test_loss', averaged_loss, prog_bar=True, rank_zero_only=True) + self.log('test_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) return averaged_loss def loss_func(self, loss_mask, tokens_loss): @@ -803,7 +714,7 @@ def process_micro_batch(self, micro_batch): return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask def _process_global_batch_without_megatron_batch_sampler(self, global_batch, tokenizer=None): - """ Prepares the global batch for apex fwd/bwd functions. + """ Prepares the global batch for megatron-core fwd/bwd functions. Global batch is a list of micro batches. """ tokenizer = self.tokenizer if tokenizer is None else tokenizer @@ -867,20 +778,6 @@ def _process_global_batch_without_megatron_batch_sampler(self, global_batch, tok 'dec_mask': dec_mask_tensor, } - def process_global_batch(self, global_batch): - # If the decoder input starts with instead of , which is the case for huggingface T5 models, we don't want to mask the first token. - # For NeMo-Megatron, the sequence starts with , which is never masked so we can always set index 0 to be unmasked. - global_batch['dec_mask'][:, 0] = 1 - - return [ - global_batch["text_enc"], - global_batch["text_dec"], - global_batch["loss_mask"], - global_batch["labels"], - global_batch["enc_mask"], - global_batch["dec_mask"], - ] - def build_train_valid_test_datasets(self): raise NotImplementedError("Please implement this method in child-class") @@ -894,7 +791,7 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, num_workers): # Megatron sampler if hasattr(self._cfg.data, 'dataloader_type') and self._cfg.data.dataloader_type is not None: if self._cfg.data.dataloader_type == 'single': - batch_sampler = MegatronPretrainingBatchSampler( + batch_sampler = MegatronPretrainingSampler( total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=self._cfg.micro_batch_size, @@ -904,11 +801,11 @@ def build_pretraining_data_loader(self, dataset, consumed_samples, num_workers): drop_last=self._cfg.get('drop_last', True), ) elif self._cfg.data.dataloader_type == 'cyclic': - batch_sampler = MegatronPretrainingRandomBatchSampler( + batch_sampler = MegatronPretrainingRandomSampler( total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=self._cfg.micro_batch_size, - global_batch_size=self._cffg.global_batch_size, + global_batch_size=self._cfg.global_batch_size, data_parallel_rank=parallel_state.get_data_parallel_rank(), data_parallel_size=parallel_state.get_data_parallel_world_size(), drop_last=self._cfg.get('drop_last', True), @@ -1075,29 +972,22 @@ def dummy(): arg_names=arg_names, output_name="hiddens", output_enc_hidden_only=True ) - # Counter intuitively, we need to set decoder_sequence_length=encoder_seq_length because while running `.enocde()`, the last hidden states from encoder are passed through as identity through the pipeline. Setting it to anything else will cause hanging due to tensor shape mismatches. - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - output_tensor = forward_backward_pipelining_without_interleaving( - forward_step_func=forward_step_func, - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=True, - tensor_shape=tensor_shape, - decoder_sequence_length=encoder_seq_length, - dtype=self.autocast_dtype, - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - ) - else: - output_tensor = forward_backward_no_pipelining( - forward_step_func=forward_step_func, - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=True, - tensor_shape=tensor_shape, - decoder_sequence_length=encoder_seq_length, - dtype=self.autocast_dtype, - sync_batch_comm=self.cfg.get('sync_batch_comm', False), - ) + fwd_bwd_func = get_forward_backward_func() + + # Counter intuitively, we need to set decoder_sequence_length=encoder_seq_length + # because while running `.encode()`, the last hidden states from encoder are passed through + # as identity through the pipeline. + # Setting it to anything else will cause hanging due to tensor shape mismatches. + output_tensor = fwd_bwd_func( + forward_step_func=forward_step_func, + data_iterator=iter([batch_for_pipeline,]), + model=[self.enc_dec_model], + forward_only=True, + tensor_shape=tensor_shape, + num_microbatches=1, + decoder_seq_length=encoder_seq_length, + dtype=self.autocast_dtype, + ) if output_tensor: output_tensor = output_tensor[0]['hiddens'] @@ -1177,7 +1067,7 @@ def decode( logging.info(f'Decoding using the {sampling_method} method...') # Check whether the DDP is initialized. This is needed when running inference outside of training loop. - if parallel_state.is_unitialized(): + if not parallel_state.model_parallel_is_initialized(): def dummy(): return @@ -1249,26 +1139,18 @@ def dummy(): arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask'] forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="logits") - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - output_tensor = forward_backward_pipelining_without_interleaving( - forward_step_func=forward_step_func, - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=True, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_seq_length, - dtype=self.autocast_dtype, - ) - else: - output_tensor = forward_backward_no_pipelining( - forward_step_func=forward_step_func, - batch=batch_for_pipeline, - model=self.enc_dec_model, - forward_only=True, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_seq_length, - dtype=self.autocast_dtype, - ) + fwd_bwd_func = get_forward_backward_func() + + output_tensor = fwd_bwd_func( + forward_step_func=forward_step_func, + data_iterator=iter([batch_for_pipeline,]), + model=[self.enc_dec_model], + forward_only=True, + tensor_shape=tensor_shape, + num_microbatches=1, + decoder_seq_length=encoder_seq_length, + dtype=self.autocast_dtype, + ) # get output tensor if parallel_state.is_pipeline_last_stage(): output_tensor = output_tensor[0]['logits'] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py index d1c7b32ae77c..a9c659e48696 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py @@ -36,7 +36,6 @@ MegatronRetrievalTokenLevelEncoderDecoderModule, ) from nemo.collections.nlp.modules.common.megatron.utils import ( - ApexGuardDefaults, average_losses_across_data_parallel_group, build_position_ids, get_params_for_weight_decay_optimization, @@ -60,13 +59,14 @@ from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state - from apex.transformer.enums import ModelType + from megatron.core import parallel_state + from megatron.core.enums import ModelType + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - ModelType = ApexGuardDefaults() - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False __all__ = ["MegatronRetrievalModel"] @@ -268,7 +268,7 @@ def training_step(self, batch, batch_idx): if self.cfg.precision == 16: loss_scale = self.trainer.precision_plugin.scaler._scale if loss_scale is not None: - self.log('loss_scale', loss_scale) + self.log('loss_scale', loss_scale, batch_size=1) if self.with_distributed_adam: # gradients are reduced internally in distributed optimizer @@ -292,55 +292,19 @@ def training_step(self, batch, batch_idx): if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0: # Reduced loss for logging. average_reduced_loss = sum(self._reduced_loss_buffer) / len(self._reduced_loss_buffer) - self.log('reduced_train_loss', average_reduced_loss, prog_bar=True) + self.log('reduced_train_loss', average_reduced_loss, prog_bar=True, batch_size=1) lr = self._optimizer.param_groups[0]['lr'] - self.log('lr', lr) - self.log('global_step', self.trainer.global_step, prog_bar=True) + self.log('lr', lr, batch_size=1) + self.log('global_step', self.trainer.global_step, prog_bar=True, batch_size=1) self.log( 'consumed_samples', self.compute_consumed_samples(self.trainer.global_step - self.init_global_step), prog_bar=True, + batch_size=1, ) self._reduced_loss_buffer = [] return lm_loss - def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[int] = 0) -> None: - super().on_train_batch_end(outputs, batch, batch_idx) - - # TODO: Replace with newer override for scheduler.step() instead of - # search for plugins for fp16 GradScalar - if self.trainer.precision_plugin is not None and isinstance( - self.trainer.precision_plugin, NativeMixedPrecisionPlugin - ): - precision_plugin = self.trainer.precision_plugin - - if ( - hasattr(precision_plugin, 'scaler') - and precision_plugin.scaler is not None - and isinstance(precision_plugin.scaler, GradScaler) - ): - grad_scaler = precision_plugin.scaler - - # If the grad scaler skipped its optimizer step due to infs/nans, - # decrement the step of all schedulers. - if grad_scaler.optimizer_update_skipped is not None and grad_scaler.optimizer_update_skipped is True: - scheduler_cfgs = self.trainer.lr_scheduler_configs - - if not scheduler_cfgs or not self.trainer.lightning_module.automatic_optimization: - return - - for scheduler_cfg in scheduler_cfgs: - # Decrement the counter by 2, then perform a scheduler.step() to perform a no-up - # as well as update the optimizer lr in all param groups - scheduler_cfg.scheduler.last_epoch -= 2 - scheduler_cfg.scheduler.step() - - # Increase the max step count by 1 - - # Reset the optimizer update skipped to `None` - this is to prevent scheduler no-ops during - # accumulated gradient updates. - grad_scaler.optimizer_update_skipped = None - def validation_step(self, batch, batch_idx): input_tokens_id = batch['tokens'] input_attn_mask = batch['tokens_mask'] @@ -369,10 +333,10 @@ def validation_epoch_end(self, outputs): if len(outputs) == 0: return averaged_loss = torch.stack(outputs).mean() - self.log('val_loss', averaged_loss, prog_bar=True) + self.log('val_loss', averaged_loss, prog_bar=True, batch_size=1) # formula to compute the perplexity # https://towardsdatascience.com/the-relationship-between-perplexity-and-entropy-in-nlp-f81888775ccc - self.log('perplexity', torch.exp(averaged_loss), prog_bar=True) + self.log('perplexity', torch.exp(averaged_loss), prog_bar=True, batch_size=1) return averaged_loss def test_step(self, batch, batch_idx): @@ -380,9 +344,9 @@ def test_step(self, batch, batch_idx): def test_epoch_end(self, outputs): averaged_loss = torch.stack(outputs).mean() - self.log('test_loss', averaged_loss, prog_bar=True) + self.log('test_loss', averaged_loss, prog_bar=True, batch_size=1) logging.info(f'test_loss: {averaged_loss} ') - self.log('perplexity', torch.exp(averaged_loss), prog_bar=True) + self.log('perplexity', torch.exp(averaged_loss), prog_bar=True, batch_size=1) return averaged_loss def build_train_valid_test_datasets(self): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py index 89aa13bb55d5..1eaec4238648 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py @@ -37,12 +37,13 @@ from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator, get_num_microbatches + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False __all__ = ['MegatronRetroFinetuneModel'] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py index 826524ca717c..a663eb10fa3e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py @@ -27,13 +27,24 @@ from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = ['MegatronT0Model'] @@ -142,8 +153,8 @@ def _build_dataset(self, data_cfg, check_implict_grad_acc=False, is_train=True): else: return datasets - def training_step(self, batch, batch_idx): - return super(MegatronT5FinetuneModel, self).training_step(batch, batch_idx) + def training_step(self, dataloader_iter, batch_idx): + return super(MegatronT5FinetuneModel, self).training_step(dataloader_iter, batch_idx) # Override the parent batch reconfiguring logic. def _reconfigure_and_process_inference_batch(self, batch): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py index 418bd70d2a75..71b3d5537efd 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py @@ -43,12 +43,12 @@ from nemo.utils import logging, model_utils try: - from apex.transformer import parallel_state + from megatron.core import parallel_state - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False class MegatronT5BaseAdapterModel(MegatronT5PromptLearningModel): @@ -143,14 +143,15 @@ def compute_accuracy(self, enc_input, enc_mask, encoder_input, labels): 'enc_inputs': processed_inputs, } - def validation_step(self, batch, batch_idx, inference=False): + def validation_step(self, dataloader_iter, batch_idx, inference=False): + batch = next(dataloader_iter) enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids = batch mode = self.training self.eval() gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) self._reconfigure_and_process_inference_batch(enc_input.size(0), gbs) - loss_mean = self.fwd_bwd_step(batch, batch_idx, forward_only=True) + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=True) if self.cfg.get('report_validation_metric', False): metrics = self.compute_accuracy(enc_input, enc_mask, labels) @@ -283,13 +284,13 @@ def validation_epoch_end(self, outputs): # we can only log on one rank if it is rank zero so we broadcast from last rank torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True) + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) logging.info(f'Validation loss: {averaged_loss}') else: averaged_loss = torch.stack([item['loss'] for item in outputs]).mean() logging.info(f'Validation loss: {averaged_loss}') - self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True) + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) if self.cfg.get('report_validation_accuracy', False): gather_results = [None for _ in range(parallel_state.get_data_parallel_world_size())] @@ -324,7 +325,7 @@ def validation_epoch_end(self, outputs): else: val_acc = torch.tensor(0.0).cuda() - self.log('val_acc', val_acc, prog_bar=True, rank_zero_only=True) + self.log('val_acc', val_acc, prog_bar=True, rank_zero_only=True, batch_size=1) gbs = self.cfg.global_batch_size mbs = self.cfg.micro_batch_size diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py index ea6bfd848bef..bfcc2c43631d 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py @@ -27,26 +27,38 @@ ) from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model -from nemo.collections.nlp.modules.common import VirtualPromptSource -from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, +) from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import AppState, logging try: - from apex.transformer import parallel_state - from apex.transformer.enums import ModelType - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator, + get_micro_batch_size, + get_num_microbatches, ) - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator, get_micro_batch_size HAVE_APEX = True except (ImportError, ModuleNotFoundError): + HAVE_APEX = False +try: + from megatron.core import parallel_state + from megatron.core.enums import ModelType + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = ['MegatronT5PromptLearningModel'] @@ -73,6 +85,7 @@ class MegatronT5PromptLearningModel(MegatronBasePromptLearningModel): def __init__(self, cfg: DictConfig, trainer: Trainer): super().__init__(cfg, trainer) + self.model_type = ModelType.encoder_and_decoder def first_stage_of_pipeline(self): if self.frozen_model.enc_dec_model.pre_process and parallel_state.get_pipeline_model_parallel_rank() == 0: @@ -159,41 +172,32 @@ def load_frozen_model(self, cfg, trainer): save_restore_connector=NLPSaveRestoreConnector(), ) - def fwd_bwd_step(self, batch, batch_idx, forward_only): + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): """ Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ # Get seq length of batch + batch = next(dataloader_iter) _, seq_length = batch[0].shape _, dec_seq_length = batch[1].shape tensor_shape = [seq_length, get_micro_batch_size(), self.hidden_size] - - if self.cfg.get('pipeline_model_parallel_size', 1) > 1: - losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch, - model=self, - forward_only=forward_only, - tensor_shape=tensor_shape, - decoder_sequence_length=dec_seq_length, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - sequence_parallel_enabled=False, - sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False), - ) - else: - losses_reduced_per_micro_batch = forward_backward_no_pipelining( - forward_step_func=self.get_forward_output_and_loss_func(), - batch=batch, - model=self, - forward_only=forward_only, - tensor_shape=tensor_shape, - decoder_sequence_length=dec_seq_length, - dtype=self.autocast_dtype, - grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, - sync_batch_comm=self.frozen_model.cfg.get('sync_batch_comm', False), - ) + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=data_iter, + model=[self], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + tensor_shape=tensor_shape, + decoder_seq_length=dec_seq_length, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + sequence_parallel=self.cfg.get('sequence_parallel', False), + ) # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: @@ -208,7 +212,8 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only): return loss_mean def get_forward_output_and_loss_func(self): - def fwd_output_and_loss_func(batch, model): + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) batch = [x.cuda(non_blocking=True) for x in batch] enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids = batch @@ -228,7 +233,7 @@ def loss_func(output_tensor): def backward(self, *args, **kwargs): """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. No need to call it here. """ return @@ -248,16 +253,6 @@ def set_input_tensor(self, input_tensor): forward_step_func""" self.frozen_model.enc_dec_model.set_input_tensor(input_tensor) - def _reconfigure_batch_sizes(self, gbs: int, mbs: int): - app_state = AppState() - _reconfigure_microbatch_calculator( - rank=app_state.global_rank, - rampup_batch_size=None, - global_batch_size=gbs, - micro_batch_size=mbs, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - def on_train_epoch_start(self) -> None: gbs = self.cfg.global_batch_size mbs = self.cfg.micro_batch_size @@ -270,9 +265,10 @@ def on_validation_epoch_start(self) -> None: self._reconfigure_batch_sizes(gbs, mbs) return super().on_validation_epoch_start() - def training_step(self, batch, batch_idx): + def training_step(self, dataloader_iter, batch_idx): self._optimizer.zero_grad() - loss_mean = self.fwd_bwd_step(batch, batch_idx, forward_only=False) + batch = next(dataloader_iter) + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=False) self.allreduce_gradients() ## logging @@ -283,27 +279,14 @@ def training_step(self, batch, batch_idx): if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"): loss_scale = self.trainer.precision_plugin.scaler._scale if loss_scale is not None: - self.log('loss_scale', loss_scale) + self.log('loss_scale', loss_scale, batch_size=1) - self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) lr = self._optimizer.param_groups[0]['lr'] - self.log('lr', lr, rank_zero_only=True) - self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True) + self.log('lr', lr, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) return loss_mean - def _reconfigure_and_process_inference_batch(self, global_batch_size_per_gpu, gbs): - # This should happen only on the last batch of the dataset. - if global_batch_size_per_gpu != gbs // parallel_state.get_data_parallel_world_size(): - # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. - app_state = AppState() - _reconfigure_microbatch_calculator( - rank=app_state.global_rank, - rampup_batch_size=None, - global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), - micro_batch_size=global_batch_size_per_gpu, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - def get_predictions(self, input_ids, enc_mask, encoder_input, labels): predicted_token_ids, log_probs = self.frozen_model.decode( tokens_enc=input_ids, @@ -326,12 +309,12 @@ def get_predictions(self, input_ids, enc_mask, encoder_input, labels): def validation_step(self, batch, batch_idx, inference=False): input_ids, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids = batch - + # does not use dataloader_iter due to device placement issues arising from PTL mode = self.training self.eval() gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) self._reconfigure_and_process_inference_batch(input_ids.size(0), gbs) - loss_mean = self.fwd_bwd_step(batch, batch_idx, forward_only=True) + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=True) if self.first_stage_of_pipeline(): # Get embeddings for text tokens and insert virtual token embeddings @@ -368,13 +351,13 @@ def validation_epoch_end(self, outputs): # we can only log on one rank if it is rank zero so we broadcast from last rank torch.distributed.broadcast(averaged_loss, get_last_rank()) - self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True) + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) logging.info(f'Validation loss: {averaged_loss}') else: averaged_loss = torch.stack([item['loss'] for item in outputs]).mean() logging.info(f'Validation loss: {averaged_loss}') - self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True) + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) if self.cfg.get("report_validation_metric", False): gather_results = [None for _ in range(parallel_state.get_data_parallel_world_size())] @@ -410,7 +393,7 @@ def validation_epoch_end(self, outputs): val_metric = torch.tensor(0.0).cuda() metric_name = '' - self.log(f'val_{metric_name}', val_metric, prog_bar=True, rank_zero_only=True) + self.log(f'val_{metric_name}', val_metric, prog_bar=True, rank_zero_only=True, batch_size=1) gbs = self.cfg.global_batch_size mbs = self.cfg.micro_batch_size diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py index 99791467cee5..fbf4b029fcdc 100644 --- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py +++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py @@ -45,19 +45,35 @@ from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model from nemo.collections.nlp.models.machine_translation.mt_enc_dec_model import MTEncDecModel from nemo.collections.nlp.modules.common.megatron.megatron_export import DecEmb, EncEmb, TokensHeadEmb +from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split from nemo.collections.nlp.parts.nlp_overrides import GlobalBatchDataFetcher from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.core.classes import Exportable from nemo.utils import AppState, logging, timers try: - from apex.transformer import parallel_state - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator, + get_micro_batch_size, + get_num_microbatches, + ) HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False +try: + from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = ["MegatronNMTModel"] @@ -270,8 +286,53 @@ def _build_vocab(self): tensor_model_parallel_size=self._cfg.get('tensor_model_parallel_size', 1), ) - def eval_step(self, batch, batch_idx, dataloader_idx=0): + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + """ + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + """ + batch = next(dataloader_iter) + if isinstance(batch, dict): + # convert to list if not already converted. + batch = self._process_batch(batch) + + # Get seq length of batch + encoder_seq_length = batch[0].size(1) + decoder_seq_length = batch[1].size(1) + + tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size] + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=data_iter, + model=[self.enc_dec_model], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + tensor_shape=tensor_shape, + decoder_seq_length=decoder_seq_length, + dtype=self.autocast_dtype, + grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None, + sequence_parallel=self.cfg.get('sequence_parallel', False), + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # we're not on the last pipeline stage so no losses + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean + + def eval_step(self, dataloader_iter, batch_idx, dataloader_idx=0): # Need to squeze dim 0 for old NMT datasets since things are pre-batched and we ask the dataloader for batch size 1. + batch = next(dataloader_iter) batch = [x.squeeze(dim=0) if x.ndim == 3 else x for x in batch] batch = self.process_global_batch_for_text_translation_datasets(batch) @@ -285,7 +346,8 @@ def eval_step(self, batch, batch_idx, dataloader_idx=0): data_parallel_size=parallel_state.get_data_parallel_world_size(), ) # This returns the averaged loss across data-parallel groups. - reduced_loss = super().validation_step(batch, batch_idx, dataloader_idx) + reduced_loss = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, True) + tokens_enc, labels, enc_mask = batch['text_enc'], batch['labels'], batch['enc_mask'] predicted_tokens_ids, _ = self.decode( @@ -344,12 +406,12 @@ def postprocess_outputs(self, outputs, tokenizer, processor): return results - def validation_step(self, batch, batch_idx, dataloader_idx=0): + def validation_step(self, dataloader_iter, batch_idx, dataloader_idx=0): """ Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. """ - return self.eval_step(batch, batch_idx, dataloader_idx) + return self.eval_step(dataloader_iter, batch_idx, dataloader_idx) def _setup_eval_dataloader_from_config(self, cfg: DictConfig, dataset): rank = parallel_state.get_data_parallel_rank() @@ -368,6 +430,7 @@ def _setup_eval_dataloader_from_config(self, cfg: DictConfig, dataset): pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), shuffle=False, + persistent_workers=True, ) ) @@ -464,18 +527,18 @@ def eval_epoch_end(self, outputs, mode): if self.multilingual: self._log_multilingual_bleu_and_loss(dataloader_idx, bleu_score, averaged_loss, mode) else: - self.log(f'{mode}_sacreBLEU', bleu_score) - self.log(f'{mode}_loss', averaged_loss, prog_bar=True) + self.log(f'{mode}_sacreBLEU', bleu_score, batch_size=1) + self.log(f'{mode}_loss', averaged_loss, prog_bar=True, batch_size=1) else: if self.multilingual: self._log_multilingual_bleu_and_loss(dataloader_idx, bleu_score, averaged_loss, mode) else: - self.log(f'{mode}_sacreBLEU_dl_index_{dataloader_idx}', bleu_score) - self.log(f'{mode}_loss_dl_index_{dataloader_idx}', averaged_loss, prog_bar=False) + self.log(f'{mode}_sacreBLEU_dl_index_{dataloader_idx}', bleu_score, batch_size=1) + self.log(f'{mode}_loss_dl_index_{dataloader_idx}', averaged_loss, prog_bar=False, batch_size=1) if len(loss_list) > 1: - self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True) - self.log(f"{mode}_sacreBLEU_avg", np.mean(bleu_score_list)) + self.log(f"{mode}_loss_avg", np.mean(loss_list), sync_dist=True, batch_size=1) + self.log(f"{mode}_sacreBLEU_avg", np.mean(bleu_score_list), batch_size=1) def _log_multilingual_bleu_and_loss(self, dataloader_idx, bleu_score, loss, mode): """ @@ -487,8 +550,8 @@ def _log_multilingual_bleu_and_loss(self, dataloader_idx, bleu_score, loss, mode else: translation_lang_string = f'{self.src_language}-{self.tgt_language[dataloader_idx]}' - self.log(f'{mode}_sacreBLEU_{translation_lang_string}', bleu_score, sync_dist=True) - self.log(f'{mode}_loss_{translation_lang_string}', loss, sync_dist=True) + self.log(f'{mode}_sacreBLEU_{translation_lang_string}', bleu_score, sync_dist=True, batch_size=1) + self.log(f'{mode}_loss_{translation_lang_string}', loss, sync_dist=True, batch_size=1) def setup_validation_data(self, val_data_config: Optional[DictConfig]): if hasattr(self, '_validation_ds'): @@ -529,6 +592,7 @@ def _setup_megatron_dataloader_from_config(self, cfg, dataset, consumed_samples) collate_fn=collate_fn, num_workers=cfg.num_workers, pin_memory=cfg.pin_memory, + persistent_workers=True, ) def process_global_batch_for_text_translation_datasets(self, batch): diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index fcf416a4d7fb..e6480362bc85 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -29,7 +29,6 @@ try: from apex.normalization.fused_layer_norm import MixedFusedLayerNorm - from apex.transformer.tensor_parallel import ColumnParallelLinear, RowParallelLinear HAVE_APEX = True @@ -37,6 +36,15 @@ HAVE_APEX = False +try: + from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + class AdapterName(str, enum.Enum): """ @@ -102,6 +110,9 @@ def __init__( if not HAVE_APEX: logging.info("Apex is required to use ParallelLinearAdapters.") raise RuntimeError("ParallelLinearAdapter can not run without Apex.") + if not HAVE_MEGATRON_CORE: + logging.info("Megatron-core is required to use ParallelLinearAdapters.") + raise RuntimeError("ParallelLinearAdapter can not run without Megatron-core.") self.activation = activation_registry[activation]() self.norm_position = norm_position diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index 5cbf1776a4b2..5c2267a25e44 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -26,7 +26,6 @@ from nemo.core import adapter_mixins try: - from apex.transformer import parallel_state, tensor_parallel from apex.transformer.enums import AttnMaskType, AttnType from apex.transformer.utils import divide as safe_divide @@ -39,6 +38,16 @@ # fake missing classes with None attributes ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults() + +try: + from megatron.core import parallel_state, tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + """ We use the following notation throughout this file: h: hidden size n: number of attention heads @@ -116,8 +125,8 @@ def __init__( self.num_attention_heads_per_partition * parallel_state.get_tensor_model_parallel_rank() ) - no_async_tensor_model_parallel_allreduce = ( - parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel + async_tensor_model_parallel_allreduce = ( + parallel_state.get_tensor_model_parallel_world_size() > 1 and not sequence_parallel ) # Strided linear layer. @@ -130,7 +139,7 @@ def __init__( use_cpu_initialization=use_cpu_initialization, bias=bias, sequence_parallel_enabled=sequence_parallel, - no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, + async_tensor_model_parallel_allreduce=async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, ) else: @@ -142,7 +151,7 @@ def __init__( init_method=init_method, bias=bias, sequence_parallel_enabled=sequence_parallel, - no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, + async_tensor_model_parallel_allreduce=async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, ) @@ -153,7 +162,7 @@ def __init__( init_method=init_method, bias=bias, sequence_parallel_enabled=sequence_parallel, - no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, + async_tensor_model_parallel_allreduce=async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, ) diff --git a/nemo/collections/nlp/modules/common/megatron/build_model.py b/nemo/collections/nlp/modules/common/megatron/build_model.py new file mode 100644 index 000000000000..4c7790773d5b --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/build_model.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable, Dict, List, Optional + +import torch + +from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults +from nemo.utils import logging + +try: + from megatron.core import parallel_state + from megatron.core.enums import ModelType + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + ModelType = ApexGuardDefaults() + + HAVE_MEGATRON_CORE = False + +try: + from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + +# Apex's `build model' refactored to call Megatron-Core classes +def build_model( + model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module], + wrap_with_ddp: bool = True, + virtual_pipeline_model_parallel_size: Optional[int] = None, + model_type: ModelType = ModelType.encoder_or_decoder, + on_cpu: bool = False, + *args: Any, + **kwargs: Any, +) -> List[torch.nn.Module]: + """Build the model satisfying pipeline model parallel requirements. + This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to + `model_provider_func`. + Args: + model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`. + wrap_with_ddp: If :obj:`True`, wrap the instantiated model + with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`. + virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel. + model_type: + *args: arguments for model provider func + **kwargs: Keyword arguments for model provider func + Returns: + a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None, + the list has multiple models, otherwise one. + """ + if model_type is None: + model_type = ModelType.encoder_or_decoder + + if ( + parallel_state.get_pipeline_model_parallel_world_size() > 1 + and virtual_pipeline_model_parallel_size is not None + ): + model = [] + for i in range(virtual_pipeline_model_parallel_size): + cur_args = args + cur_kwargs = kwargs + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. + pre_process = parallel_state.is_pipeline_first_stage() + post_process = parallel_state.is_pipeline_last_stage() + cur_kwargs.update( + {"pre_process": pre_process, "post_process": post_process,} + ) + this_model = model_provider_func(*cur_args, **cur_kwargs) + model.append(this_model) + else: + cur_args = args + cur_kwargs = kwargs + if model_type == ModelType.encoder_or_decoder: + pre_process = parallel_state.is_pipeline_first_stage() + post_process = parallel_state.is_pipeline_last_stage() + cur_kwargs.update( + {"pre_process": pre_process, "post_process": post_process,} + ) + model = model_provider_func(*cur_args, **cur_kwargs) + elif model_type == ModelType.encoder_and_decoder: + pre_process = parallel_state.is_pipeline_first_stage() + post_process = parallel_state.is_pipeline_last_stage() + # `add_encoder` & `add_decoder` logic. + add_encoder, add_decoder = True, True + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + split_rank = parallel_state.get_pipeline_model_parallel_split_rank() + if split_rank is None: + raise RuntimeError("Split rank needs to be specified for model with both encoder and decoder.") + rank = parallel_state.get_pipeline_model_parallel_rank() + world_size = parallel_state.get_pipeline_model_parallel_world_size() + pre_process = rank == 0 or rank == split_rank + post_process = rank == (split_rank - 1) or rank == (world_size - 1) + add_encoder = parallel_state.is_pipeline_stage_before_split() + add_decoder = parallel_state.is_pipeline_stage_after_split() + cur_kwargs.update( + { + "pre_process": pre_process, + "post_process": post_process, + "add_encoder": add_encoder, + "add_decoder": add_decoder, + } + ) + model = model_provider_func(*cur_args, **cur_kwargs) + else: + raise ValueError(f"Unrecognized ModelType '{model_type}'") + + model.model_type = model_type + + if not isinstance(model, list): + model = [model] + + # Set tensor model parallel attributes if not set. + # Only parameters that are already tensor model parallel have these + # attributes set for them. We should make sure the default attributes + # are set for all params so the optimizer can use them. + for model_module in model: + for param in model_module.parameters(): + set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # Print number of parameters. + if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0: + msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( + parallel_state.get_tensor_model_parallel_rank(), + parallel_state.get_pipeline_model_parallel_rank(), + _calc_number_of_params(model), + ) + logging.info(msg) + + # GPU allocation. + if not on_cpu: + for model_module in model: + model_module.cuda(torch.cuda.current_device()) + + if wrap_with_ddp: + i = torch.cuda.current_device() + model = [ + torch.nn.parallel.distributed.DistributedDataParallel( + model_module, device_ids=[i], output_device=i, process_group=parallel_state.get_data_parallel_group(), + ) + for model_module in model + ] + return model + + +def _calc_number_of_params(model: List[torch.nn.Module]) -> int: + assert isinstance(model, list) + return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model]) diff --git a/nemo/collections/nlp/modules/common/megatron/clip_grads.py b/nemo/collections/nlp/modules/common/megatron/clip_grads.py index 00a8244522ed..68a97485edf6 100644 --- a/nemo/collections/nlp/modules/common/megatron/clip_grads.py +++ b/nemo/collections/nlp/modules/common/megatron/clip_grads.py @@ -24,22 +24,34 @@ try: import amp_C from apex.multi_tensor_apply import multi_tensor_applier - from apex.transformer import parallel_state - from apex.transformer.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False HAVE_APEX_DISTRIBUTED_ADAM = False + if HAVE_APEX: try: from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam HAVE_APEX_DISTRIBUTED_ADAM = True + except (ImportError, ModuleNotFoundError): pass +try: + from megatron.core import parallel_state + from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): """Clips gradient norm of an iterable of parameters whose gradients diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 902d571ae170..0ab2ae79bed1 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -27,17 +27,27 @@ ) try: - from apex.transformer import parallel_state, tensor_parallel from apex.transformer.enums import AttnMaskType HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False # fake missing classes with None attributes AttnMaskType = ApexGuardDefaults() LayerType = ApexGuardDefaults() +try: + from megatron.core import parallel_state, tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + def get_language_model( hidden_size, diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index e7e3aeec31da..65a788de438c 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -20,10 +20,18 @@ from nemo.utils import AppState, logging try: - from apex.transformer import tensor_parallel from apex.transformer.log_util import set_logging_level from apex.transformer.microbatches import ConstantNumMicroBatches - from apex.transformer.parallel_state import ( + from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + +try: + from megatron.core import tensor_parallel + from megatron.core.parallel_state import ( get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank, set_pipeline_model_parallel_split_rank, @@ -32,14 +40,12 @@ set_tensor_model_parallel_world_size, set_virtual_pipeline_model_parallel_rank, ) - from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False try: - # TODO: remove when apex is updated from apex.transformer.parallel_state import set_virtual_pipeline_model_parallel_world_size HAVE_INTERLEAVED = True @@ -65,7 +71,7 @@ def initialize_model_parallel_for_nemo( ): if virtual_pipeline_model_parallel_size is not None and not HAVE_INTERLEAVED: - raise ValueError("set_virtual_pipeline_model_parallel_world_size is needed in Apex for interleaved.") + raise ValueError("set_virtual_pipeline_model_parallel_world_size is needed in megatron-core for interleaved.") # updating NeMo globals app_state = AppState() @@ -171,7 +177,7 @@ def fake_initialize_model_parallel( """ Fake initialize model data parallel groups so that we can instantiate model parallel models before DDP is initialized. This is needed because PTL execution flow is init model, init trainer -> call trainer.fit(model). DDP is initialized during .fit. - This function is taken from apex.transformer.parallel_state and modified so that the distributed groups are not created. + This function is taken from megatron.core.parallel_state and modified so that the distributed groups are not created. We only need the tensor parallel and pipeline parallel ranks to instantiate the model. Arguments: diff --git a/nemo/collections/nlp/modules/common/megatron/mlp.py b/nemo/collections/nlp/modules/common/megatron/mlp.py index f3ddd0f89b44..43e30784c63a 100644 --- a/nemo/collections/nlp/modules/common/megatron/mlp.py +++ b/nemo/collections/nlp/modules/common/megatron/mlp.py @@ -32,7 +32,6 @@ try: from apex.normalization import MixedFusedRMSNorm from apex.transformer import parallel_state, tensor_parallel - from apex.transformer.parallel_state import get_tensor_model_parallel_world_size HAVE_APEX = True @@ -44,6 +43,17 @@ ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults() +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.parallel_state import get_tensor_model_parallel_world_size + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + class ParallelMLP(MegatronModule, adapter_mixins.AdapterModuleMixin): """MLP. @@ -100,8 +110,8 @@ def __init__( ) self.fast_glu_activation = activation in ['fast-geglu', 'fast-swiglu', 'fast-reglu'] - no_async_tensor_model_parallel_allreduce = ( - parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel + async_tensor_model_parallel_allreduce = ( + parallel_state.get_tensor_model_parallel_world_size() > 1 and not sequence_parallel ) # Project to 4h. self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( @@ -115,7 +125,7 @@ def __init__( use_cpu_initialization=use_cpu_initialization, bias=bias, sequence_parallel_enabled=sequence_parallel, - no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, + async_tensor_model_parallel_allreduce=async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, ) @@ -131,7 +141,7 @@ def __init__( use_cpu_initialization=use_cpu_initialization, bias=bias, sequence_parallel_enabled=sequence_parallel, - no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, + async_tensor_model_parallel_allreduce=async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, ) diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py index 89d5dc1f7443..52464b819c2f 100644 --- a/nemo/collections/nlp/modules/common/megatron/module.py +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -21,13 +21,13 @@ from nemo.utils import logging try: - from apex.transformer import parallel_state, tensor_parallel + from megatron.core import parallel_state, tensor_parallel - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) @@ -44,9 +44,9 @@ class MegatronModule(torch.nn.Module): for pipelining.""" def __init__(self, share_token_embeddings=True): - if not HAVE_APEX: + if not HAVE_MEGATRON_CORE: raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) super(MegatronModule, self).__init__() self.share_token_embeddings = share_token_embeddings @@ -255,9 +255,9 @@ def float_conversion(val): class Float16Module(MegatronModule): def __init__(self, module, precision): - if not HAVE_APEX: + if not HAVE_MEGATRON_CORE: raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + "Megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) super().__init__() self.precision = precision diff --git a/nemo/collections/nlp/modules/common/megatron/mup/layer.py b/nemo/collections/nlp/modules/common/megatron/mup/layer.py index ce969e455d6e..71bff3b057e0 100644 --- a/nemo/collections/nlp/modules/common/megatron/mup/layer.py +++ b/nemo/collections/nlp/modules/common/megatron/mup/layer.py @@ -44,11 +44,13 @@ from nemo.utils import logging try: - from apex.transformer import parallel_state + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False class MuReadout(MegatronModule): diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/retrieval_token_level_encoder_decoder.py index 138361b9c958..0b164a80e0e4 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_token_level_encoder_decoder.py @@ -28,7 +28,6 @@ ) try: - from apex.transformer import tensor_parallel from apex.transformer.enums import ModelType HAVE_APEX = True @@ -38,6 +37,14 @@ AttnMaskType = ApexGuardDefaults() ModelType = ApexGuardDefaults() +try: + from megatron.core import tensor_parallel + + HAVE_MEGATRON_CORE = True +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = True + __all__ = ["MegatronRetrievalTokenLevelEncoderDecoderModule"] diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index acb61d421e03..a02fb5300912 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -37,16 +37,27 @@ from nemo.collections.nlp.modules.common.megatron.vocab_parallel_cross_entropy import vocab_parallel_cross_entropy try: - from apex.transformer import parallel_state, tensor_parallel from apex.transformer.enums import AttnMaskType, ModelType HAVE_APEX = True + except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + # fake missing classes with None attributes AttnMaskType = ApexGuardDefaults() ModelType = ApexGuardDefaults() + HAVE_APEX = False + +try: + from megatron.core import parallel_state, tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = ["MegatronTokenLevelHead", "MegatronTokenLevelEncoderDecoderModule"] diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index b3a04f09b186..85d055f70e37 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -43,7 +43,6 @@ try: from apex.normalization import MixedFusedRMSNorm - from apex.transformer import parallel_state, tensor_parallel from apex.transformer.enums import AttnMaskType, AttnType, ModelType HAVE_APEX = True @@ -55,6 +54,15 @@ # fake missing classes with None attributes ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults() +try: + from megatron.core import parallel_state, tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + try: from transformer_engine.common import recipe from transformer_engine.pytorch import TransformerLayer, fp8_autocast diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index dc64c181941b..696f4c257822 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -13,25 +13,35 @@ # limitations under the License. """Utilities for models.""" - +import itertools import math -from typing import Dict, List, Tuple, Union +from typing import Dict, Iterator, List, Tuple, Union import torch try: from apex.normalization import MixedFusedRMSNorm from apex.normalization.fused_layer_norm import FusedLayerNorm # NOQA - from apex.transformer import parallel_state, tensor_parallel from apex.transformer.enums import AttnMaskType from apex.transformer.layers.layer_norm import FastLayerNorm from apex.transformer.pipeline_parallel.schedules.common import listify_model - from apex.transformer.tensor_parallel.layers import linear_with_grad_accumulation_and_async_allreduce HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.tensor_parallel.layers import linear_with_grad_accumulation_and_async_allreduce + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + class ApexGuardDefaults(object): """ @@ -356,3 +366,10 @@ def get_all_params_for_weight_decay_optimization( ] return ({'params': weight_decay_params},) + + +def get_iterator_k_split(batch: List[torch.Tensor], microbatches: int) -> Iterator: + assert batch[0].shape[0] % microbatches == 0, "Issue with batch size configuration!" + split_batch = [torch.tensor_split(item, microbatches, dim=0) for item in batch] + microbatches = [[elem[i] for elem in split_batch] for i in range(microbatches)] + return itertools.chain(microbatches) diff --git a/nemo/collections/nlp/modules/common/megatron/vocab_parallel_cross_entropy.py b/nemo/collections/nlp/modules/common/megatron/vocab_parallel_cross_entropy.py index a1b09ea85354..4f5bd8d572bb 100644 --- a/nemo/collections/nlp/modules/common/megatron/vocab_parallel_cross_entropy.py +++ b/nemo/collections/nlp/modules/common/megatron/vocab_parallel_cross_entropy.py @@ -14,16 +14,18 @@ import torch try: - from apex.transformer.parallel_state import ( + from megatron.core.parallel_state import ( get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) - from apex.transformer.tensor_parallel.utils import VocabUtility + from megatron.core.tensor_parallel.utils import VocabUtility + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False __all__ = ["vocab_parallel_cross_entropy"] diff --git a/nemo/collections/nlp/modules/common/prompt_encoder.py b/nemo/collections/nlp/modules/common/prompt_encoder.py index ebdc9452a673..282ad053bc86 100644 --- a/nemo/collections/nlp/modules/common/prompt_encoder.py +++ b/nemo/collections/nlp/modules/common/prompt_encoder.py @@ -26,15 +26,13 @@ from nemo.core.neural_types import ChannelType, NeuralType try: - from apex.transformer import parallel_state, tensor_parallel + from megatron.core import tensor_parallel - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - # fake missing classes with None attributes - ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults() + HAVE_MEGATRON_CORE = False __all__ = ["PromptEncoder", "PromptEncoderType"] @@ -156,9 +154,6 @@ def __init__( sequence_parallel = False gradient_accumulation_fusion = False - no_async_tensor_model_parallel_allreduce = ( - parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel - ) self.first = tensor_parallel.ColumnParallelLinear( self.output_size, self.hidden_size, @@ -168,7 +163,6 @@ def __init__( use_cpu_initialization=False, bias=True, sequence_parallel_enabled=sequence_parallel, - no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, gradient_accumulation_fusion=gradient_accumulation_fusion, ) self.second = tensor_parallel.RowParallelLinear( diff --git a/nemo/collections/nlp/modules/common/prompt_table.py b/nemo/collections/nlp/modules/common/prompt_table.py index a4ad43bd237e..4cb2262837ee 100644 --- a/nemo/collections/nlp/modules/common/prompt_table.py +++ b/nemo/collections/nlp/modules/common/prompt_table.py @@ -13,20 +13,6 @@ # limitations under the License. import enum -import math - -import torch -import torch.nn as nn -import torch.nn.init as init - -from nemo.core.classes import Exportable, NeuralModule - -try: - from apex.transformer import tensor_parallel - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False __all__ = ['VirtualPromptSource', 'VirtualPromptStyle', 'VirtualPromptPlaceholderToken'] diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index fa94fc9c3910..a2e7f351ae09 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -21,15 +21,23 @@ from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids try: - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import forward_backward_no_pipelining - from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, - ) + from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True + except (ImportError, ModuleNotFoundError): HAVE_APEX = False +try: + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + # the text representation of eos_id, it applies for all tokenizers END_OF_SEQ = '<|endoftext|>' @@ -44,27 +52,18 @@ def __init__(self, model): self.model.eval() def forward_step(self, batch, tensor_shape): + fwd_bwd_function = get_forward_backward_func() + + output_tensor = fwd_bwd_function( + forward_step_func=self.model.get_forward_output_only_func(), + data_iterator=batch, + model=[self.forward_model], + num_microbatches=get_num_microbatches(), + forward_only=True, + tensor_shape=tensor_shape, + dtype=self.model.autocast_dtype, + ) - if self.model.cfg.get('pipeline_model_parallel_size', 1) > 1: - output_tensor = forward_backward_pipelining_without_interleaving( - forward_step_func=self.model.get_forward_output_only_func(), - batch=batch, - model=self.forward_model, - forward_only=True, - tensor_shape=tensor_shape, - dtype=self.model.autocast_dtype, - sync_batch_comm=self.model.cfg.get('sync_batch_comm', False), - ) - else: - output_tensor = forward_backward_no_pipelining( - forward_step_func=self.model.get_forward_output_only_func(), - batch=batch, - model=self.forward_model, - forward_only=True, - tensor_shape=tensor_shape, - dtype=self.model.autocast_dtype, - sync_batch_comm=self.model.cfg.get('sync_batch_comm', False), - ) return output_tensor def tokenize_batch(self, sentences, max_len, add_BOS): diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index cf0aa514aaf4..b39ac406d4a4 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -30,13 +30,23 @@ from nemo.utils import AppState try: - from apex.transformer import parallel_state, tensor_parallel from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator HAVE_APEX = True + except (ImportError, ModuleNotFoundError): + HAVE_APEX = False +try: + from megatron.core import parallel_state, tensor_parallel + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + __all__ = [ "get_default_sampling_params", "get_default_length_params", diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py index e2e1672fd789..1a7318c4e7ce 100644 --- a/nemo/collections/nlp/modules/common/tokenizer_utils.py +++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py @@ -26,16 +26,16 @@ from nemo.collections.common.tokenizers.youtokentome_tokenizer import YouTokenToMeTokenizer from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import get_huggingface_pretrained_lm_models_list from nemo.collections.nlp.modules.common.lm_utils import get_pretrained_lm_models_list -from nemo.collections.nlp.parts.nlp_overrides import HAVE_APEX +from nemo.collections.nlp.parts.nlp_overrides import HAVE_MEGATRON_CORE from nemo.utils import logging try: from nemo.collections.nlp.modules.common.megatron.megatron_utils import get_megatron_tokenizer - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False __all__ = ['get_tokenizer', 'get_tokenizer_list'] @@ -101,9 +101,9 @@ def get_tokenizer( special_tokens_dict = special_tokens if 'megatron' in tokenizer_name: - if not HAVE_APEX: + if not HAVE_MEGATRON_CORE: raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + "Megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) if vocab_file is None: vocab_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_vocab_file( diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index c3668b559cb1..3b11eb838a2f 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -42,7 +42,6 @@ from nemo.utils.model_utils import inject_model_parallel_rank try: - from apex.transformer import parallel_state from apex.transformer.enums import ModelType from apex.transformer.pipeline_parallel.schedules.common import _calc_number_of_params from apex.transformer.pipeline_parallel.utils import get_num_microbatches @@ -54,6 +53,15 @@ HAVE_APEX = False +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE" @@ -78,6 +86,11 @@ def __init__( raise ImportError( "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) + + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) super().__init__(parallel_devices, cluster_environment, checkpoint_io, **kwargs) self.no_ddp_communication_hook = no_ddp_communication_hook @@ -87,7 +100,7 @@ def setup_distributed(self, global_rank: int = None, world_size: int = None) -> super().setup_distributed() # init model parallel if needed - if parallel_state.is_unitialized(): + if not parallel_state.model_parallel_is_initialized(): app_state = AppState() if app_state.model_parallel_size is not None: @@ -151,11 +164,10 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None: parallel_state.destroy_model_parallel() if torch.distributed.is_initialized(): parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=app_state.tensor_model_parallel_size, - pipeline_model_parallel_size_=app_state.pipeline_model_parallel_size, - pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank, - virtual_pipeline_model_parallel_size_=app_state.virtual_pipeline_model_parallel_size, - use_fp8_=app_state.use_fp8, + tensor_model_parallel_size=app_state.tensor_model_parallel_size, + pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size, + pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, ) # assert that fake tp and pp rank match after model parallel init @@ -240,6 +252,10 @@ def __init__(self) -> None: # raise ImportError( # "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." # ) + if not HAVE_MEGATRON_CORE: + logging.warning( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) super().__init__() def save_to(self, model, save_path: str): @@ -383,17 +399,24 @@ def restore_from( class PipelineMixedPrecisionPlugin(NativeMixedPrecisionPlugin): """ Overrides PTL autocasting to not wrap training/val/test_step. - We do this because we have the Apex fwd/bwd functions in training_step. + We do this because we have the megatron-core fwd/bwd functions in training_step. This means .backward is being called in training_step so we do not want the whole step wrapped in autocast. - We instead wrap the fwd_output_and_loss_func that is passed to the Apex fwd/bwd functions. + We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions. """ def __init__( self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None ) -> None: super().__init__(precision, device, scaler=scaler) + dtype = None + if precision == 16: + dtype = torch.float16 + elif precision == 'bf16': + dtype = torch.bfloat16 + + torch.set_autocast_gpu_dtype(dtype) @contextmanager def forward_context(self) -> Generator[None, None, None]: @@ -428,6 +451,9 @@ def __init__( self.hysteresis = hysteresis self._hysteresis_tracker = self.hysteresis + def __call__(self, outputs): + return self.scale(outputs) + def _unscale_grads_(self, optimizer, *args): if getattr(optimizer, "_custom_amp_unscale_grads", False): return optimizer.unscale_grads(*args) @@ -594,6 +620,13 @@ def __init__( self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None ) -> None: super().__init__(precision, device, scaler) + dtype = None + if precision == 16: + dtype = torch.float16 + elif precision == 'bf16': + dtype = torch.bfloat16 + + torch.set_autocast_gpu_dtype(dtype) def optimizer_step( self, @@ -651,6 +684,8 @@ def __init__(self, prefetch_batches: int = 0, store_on_device: bool = False) -> if not HAVE_APEX: logging.warning("Apex was not found. Using model parallel or megatron models will error out.") + if not HAVE_MEGATRON_CORE: + logging.warning("Megatron-core was not found. Using model parallel or megatron models will error out..") super().__init__(prefetch_batches=prefetch_batches, store_on_device=store_on_device) @@ -664,130 +699,3 @@ def _fetch_next_batch(self, iterator: Iterator) -> None: assert isinstance(dataloader, Sized) # `_has_len` is True self.done = self.fetched >= len(dataloader) self.on_fetch_end(batch, start_output) - - -def build_model_cpu( - model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module], - wrap_with_ddp: bool = True, - virtual_pipeline_model_parallel_size: Optional[int] = None, - model_type: 'ModelType' = None, # Default of ModelType.encoder_or_decoder - *args: Any, - **kwargs: Any, -) -> List[torch.nn.Module]: - """Build the model satisfying pipeline model parallel requirements on CPU. - - NOTE: This function is a duplicate of the Apex `build_model` function with the only difference - being that it removes an explicit forced cast of the model onto the GPU. This is necessary - because in certain cases the model is simply too large to fit on the GPU, and the user - should be responsible for casting the model to the GPU if they wish to do so. - - Specifically, this function is used when constructing a dummy model containing the union of parameters - from all TP and PP ranks, which is then sharded across a new TP PP configuration. - - This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to - `model_provider_func`. - - Args: - model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`. - wrap_with_ddp: If :obj:`True`, wrap the instantiated model - with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`. - virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel. - model_type: - *args: arguments for model provider func - **kwargs: Keyword arguments for model provider func - - Returns: - a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None, - the list has multiple models, otherwise one. - """ - if not HAVE_APEX: - raise ValueError("Apex is required for pipeline parallelism.") - - if model_type is None: - model_type = ModelType.encoder_or_decoder - - if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and virtual_pipeline_model_parallel_size is not None - ): - model = [] - for i in range(virtual_pipeline_model_parallel_size): - cur_args = args - cur_kwargs = kwargs - parallel_state.set_virtual_pipeline_model_parallel_rank(i) - # Set pre_process and post_process only after virtual rank is set. - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - cur_kwargs.update( - {"pre_process": pre_process, "post_process": post_process,} - ) - this_model = model_provider_func(*cur_args, **cur_kwargs) - model.append(this_model) - else: - cur_args = args - cur_kwargs = kwargs - model = None - if model_type == ModelType.encoder_or_decoder: - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - cur_kwargs.update( - {"pre_process": pre_process, "post_process": post_process,} - ) - model = model_provider_func(*cur_args, **cur_kwargs) - elif model_type == ModelType.encoder_and_decoder: - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - # `add_encoder` & `add_decoder` logic. - add_encoder, add_decoder = True, True - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - split_rank = parallel_state.get_pipeline_model_parallel_split_rank() - if split_rank is None: - raise RuntimeError("Split rank needs to be specified for model with both encoder and decoder.") - rank = parallel_state.get_pipeline_model_parallel_rank() - world_size = parallel_state.get_pipeline_model_parallel_world_size() - pre_process = rank == 0 or rank == split_rank - post_process = rank == (split_rank - 1) or rank == (world_size - 1) - add_encoder = parallel_state.is_pipeline_stage_before_split() - add_decoder = parallel_state.is_pipeline_stage_after_split() - cur_kwargs.update( - { - "pre_process": pre_process, - "post_process": post_process, - "add_encoder": add_encoder, - "add_decoder": add_decoder, - } - ) - model = model_provider_func(*cur_args, **cur_kwargs) - - if model is not None: - model.model_type = model_type - - if not isinstance(model, list): - model = [model] - - # Set tensor model parallel attributes if not set. - # Only parameters that are already tensor model parallel have these - # attributes set for them. We should make sure the default attributes - # are set for all params so the optimizer can use them. - for model_module in model: - for param in model_module.parameters(): - set_defaults_if_not_set_tensor_model_parallel_attributes(param) - - # Print number of parameters. - if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0: - msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_pipeline_model_parallel_rank(), - _calc_number_of_params(model), - ) - print(msg, flush=True) - - if wrap_with_ddp: - i = torch.cuda.current_device() - model = [ - torch.nn.parallel.distributed.DistributedDataParallel( - model_module, device_ids=[i], output_device=i, process_group=parallel_state.get_data_parallel_group(), - ) - for model_module in model - ] - return model diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index 2a9226ddb417..1f2ce90f3ff7 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -21,7 +21,7 @@ _coalescing_manager, _disable_pre_forward_hook, ) -from apex.transformer import parallel_state +from megatron.core import parallel_state def _str_to_dtype(dtype): diff --git a/nemo/core/optim/optimizer_with_main_params.py b/nemo/core/optim/optimizer_with_main_params.py index ab8fc68c938e..cab5e84fda2f 100644 --- a/nemo/core/optim/optimizer_with_main_params.py +++ b/nemo/core/optim/optimizer_with_main_params.py @@ -21,8 +21,6 @@ try: import amp_C from apex.multi_tensor_apply import multi_tensor_applier - from apex.transformer.parallel_state import get_data_parallel_group, get_data_parallel_world_size - from apex.transformer.tensor_parallel import copy_tensor_model_parallel_attributes HAVE_APEX = True @@ -30,6 +28,16 @@ HAVE_APEX = False +try: + from megatron.core.parallel_state import get_data_parallel_group, get_data_parallel_world_size + from megatron.core.tensor_parallel import copy_tensor_model_parallel_attributes + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + def _zero_grad_group_helper(group, set_to_none): """Zero out the gradient for a group of parameters. @@ -71,6 +79,11 @@ def __init__(self, numel, chunk_size_mb): "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + self.numel = numel self.data = torch.zeros(self.numel, dtype=torch.float, device=torch.cuda.current_device(), requires_grad=False) @@ -169,6 +182,11 @@ def __init__( "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + self.optimizer = optimizer assert self.optimizer, 'no optimizer is provided.' if contiguous_grad_bucket: diff --git a/nemo/utils/distributed.py b/nemo/utils/distributed.py index 709d4180426e..b0d24de3e5b4 100644 --- a/nemo/utils/distributed.py +++ b/nemo/utils/distributed.py @@ -19,11 +19,11 @@ from nemo.utils import logging try: - from apex.transformer import parallel_state + from megatron.core import parallel_state - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False def initialize_distributed(args, backend='nccl'): diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index fc2b4ad6f27f..cc0ce744a9a6 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -231,7 +231,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): from apex.normalization import MixedFusedRMSNorm from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax - from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear + from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]: """ diff --git a/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py b/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py index 9700c32edf1d..61cbbc1ae682 100644 --- a/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py @@ -29,11 +29,13 @@ from nemo.utils.model_utils import inject_model_parallel_rank try: - from apex.transformer import parallel_state + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False """ This is the script to convert the p-tuning PTL checkpoint file to nemo file for evaluation. diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py index 6956028c0ab7..4a373dcaf278 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py @@ -25,11 +25,11 @@ from nemo.core.config import hydra_runner try: - from apex.transformer import parallel_state + from megatron.core import parallel_state - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False """ This is the script to launch RETRO Model text generation server. diff --git a/tests/collections/nlp/test_indexed_retrieval_dataset.py b/tests/collections/nlp/test_indexed_retrieval_dataset.py index ab2b18d390a5..e35c3ab36840 100644 --- a/tests/collections/nlp/test_indexed_retrieval_dataset.py +++ b/tests/collections/nlp/test_indexed_retrieval_dataset.py @@ -31,15 +31,17 @@ from nemo.collections.nlp.data.language_modeling.megatron.retro_dataset import RETRODataset try: - from apex.transformer import parallel_state + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + + HAVE_MEGATRON_CORE = False @pytest.mark.run_only_on('GPU') -@pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed") +@pytest.mark.skipif(not HAVE_MEGATRON_CORE, reason="megatron-core is not installed") class TestRetrievalIndexFiles: @classmethod def setup_class(cls): @@ -439,7 +441,7 @@ def test_knn_index(self): os.remove(merged_file) @pytest.mark.unit - @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed") + @pytest.mark.skipif(not HAVE_MEGATRON_CORE, reason="megatron-core is not installed") def test_retro_dataset(self): chunk_size = 64 @@ -593,7 +595,7 @@ class Tokenizer: os.remove(shuffle_idx_filename) @pytest.mark.unit - @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed") + @pytest.mark.skipif(not HAVE_MEGATRON_CORE, reason="megatron-core is not installed") def test_retro_dataset_stride32(self): chunk_size = 64 pad_id = 0 @@ -747,7 +749,7 @@ class Tokenizer: os.remove(shuffle_idx_filename) @pytest.mark.unit - @pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed") + @pytest.mark.skipif(not HAVE_MEGATRON_CORE, reason="megatron-core is not installed") def test_dedup(self): total = 1000 id_start = np.array([0, 100, 200, 300, 500, 900]) diff --git a/tests/collections/nlp/test_retrieval_module.py b/tests/collections/nlp/test_retrieval_module.py index 175fdeaadab6..3a2d46f4fed2 100644 --- a/tests/collections/nlp/test_retrieval_module.py +++ b/tests/collections/nlp/test_retrieval_module.py @@ -43,9 +43,18 @@ except (ImportError, ModuleNotFoundError): HAVE_APEX = False +try: + from megatron.core.enums import ModelType + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + @pytest.mark.run_only_on('GPU') -@pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed") +@pytest.mark.skipif(not HAVE_APEX or not HAVE_MEGATRON_CORE, reason="apex or megatron-core is not installed") class TestRetrievalModule: @classmethod def setup_class(cls): diff --git a/tests/collections/nlp/test_retrieval_module_inference.py b/tests/collections/nlp/test_retrieval_module_inference.py index 4c655bc1f70f..16e7e556bd10 100644 --- a/tests/collections/nlp/test_retrieval_module_inference.py +++ b/tests/collections/nlp/test_retrieval_module_inference.py @@ -44,9 +44,18 @@ except (ImportError, ModuleNotFoundError): HAVE_APEX = False +try: + from megatron.core.enums import ModelType + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + @pytest.mark.run_only_on('GPU') -@pytest.mark.skipif(not HAVE_APEX, reason="apex is not installed") +@pytest.mark.skipif(not HAVE_APEX or not HAVE_MEGATRON_CORE, reason="apex or megatron-core is not installed") class TestRetrievalModuleInference: @classmethod def setup_class(cls): diff --git a/tutorials/nlp/Multitask_Prompt_and_PTuning.ipynb b/tutorials/nlp/Multitask_Prompt_and_PTuning.ipynb index 5933af782cc5..bd5e09e7c1f9 100644 --- a/tutorials/nlp/Multitask_Prompt_and_PTuning.ipynb +++ b/tutorials/nlp/Multitask_Prompt_and_PTuning.ipynb @@ -1,786 +1,786 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "b7a434f4", - "metadata": {}, - "outputs": [], - "source": [ - "BRANCH='main'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "developmental-gibraltar", - "metadata": {}, - "outputs": [], - "source": [ - "\"\"\"\n", - "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", - "\n", - "Instructions for setting up Colab are as follows:\n", - "1. Open a new Python 3 notebook.\n", - "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", - "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", - "4. Run this cell to set up dependencies.\n", - "\"\"\"\n", - "# If you're using Google Colab and not running locally, run this cell\n", - "\n", - "# install NeMo\n", - "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[nlp]" - ] - }, - { - "cell_type": "markdown", - "id": "42daf8bf", - "metadata": {}, - "source": [ - "# Introduction\n", - "\n", - "In this notebook we demonstrate how to use p-tuning and prompt tuning within NeMo-Megatron. Both methods are parameter efficient alternatives to fine-tuning pretrained language models. Our NeMo implementation makes it possible to use one pretrained GPT model on many downstream tasks without needing to tune the model’s full set of parameters. It also allows for adding new tasks to your model without overwriting or disrupting previous tasks for which the model has already been p-tuned/prompt-tuned. Because the original model parameters are frozen and never altered by either method, p-tuning/prompt-tuning also avoid catastrophic forgetting issues often encountered when fine-tuning models.\n", - "\n", - "- Our prompt tuning implementation is based off Lester et. al’s EMNLP 2021 paper [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/abs/2104.08691)\n", - "\n", - "- Our p-tuning implementation is based off Liu et al's paper [GPT Understands, Too](https://arxiv.org/abs/2103.10385).\n", - "\n", - "- Command line usage examples and API documentation can be found in [our user docs](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/nemo_megatron/prompt_learning.html). \n", - "\n", - "\"Prompt\n", - "\n", - "Our continuous learning capability for combined p-tuning and prompt tuning with GPT style models is a NeMo specific extension of the author’s original work.\n", - "\n", - "# The Plan\n", - "\n", - "We are going to show you how to:\n", - " \n", - " 1. P-Tune/Prompt Tune a model on multiple tasks at the same time\n", - " 2. Add a new task to a model that has already been P-Tuned/Prompt Tuned previously\n", - " \n", - "We will first p-tune a GPT model on sentiment analysis, and intent and slot classification tasks. Then we will show how to add the squad question answering task to the same model we already p-tuned once.\n", - "\n", - "\n", - "# Technical Overview\n", - "Instead of selecting discrete text prompts in a manual or automated fashion, prompt tuning and p-tuning utilize virtual prompt embeddings that can be optimized via gradient decent. The only difference between prompt tuning and p-tuning within NeMo-Megatron is the architecture used to tune the soft prompt tokens during training.\n", - "\n", - "### Terminology\n", - "We will be using the terms `continuous`, `soft`, and `virtual` token interchangeably to refer to embeddings inserted into the model prompt that have no concrete mapping to strings or characters within the model’s vocabulary. These virtual token embeddings exist in contrast to the `discrete`, `hard`, or `real` tokens that do make up the model’s vocabulary. Virtual tokens are purely 1D vectors with dimensionality equal to that of each real token embedding, matching the `hidden_size` hyperparameter. In training and inference, continuous token embeddings are inserted among discrete token embeddings according to a template you provide in the model’s config. We will demonstrate how to do this below.\n", - "\n", - "When referring to p-tuning and prompt tuning together, we will be using the phrase prompt learning for simplicity.\n", - "\n", - "### Prompt-Tuning\n", - "In prompt-tuning a pretrained GPT model, soft prompt embeddings are initialized as a 2D matrix of size `total_virtual_tokens X hidden_size`. Each task the model is prompt-tuned to perform has its own 2D embedding matrix associated with it. Tasks do not share any parameters during training or inference. All GPT model parameters are frozen and only the embedding parameters for each task are updated during training.\n", - "\n", - "In prompt tuning you can specify how the embeddings are initialized for each task. You can either\n", - "\n", - "1. Initialize embedding parameters according to some random distribution\n", - "2. Initialize embedding parameters from existing vocabulary embeddings (recommended)\n", - "\n", - "If you choose to initialize virtual token embeddings from existing embedding weights, you can provide the string of words you want to use for initialization in the model’s config. This string will be tokenized and tiled or truncated to match the specified number of virtual tokens you would like to use (`total_virtual_tokens`). Vocab embeddings are copied and used to initialize the soft prompt embedding matrix for each task. The vocab embeddings themselves are not updated or changed during prompt tuning.\n", - "\n", - "\n", - "### P-Tuning\n", - "In p-tuning, an LSTM model is used to predict virtual token embeddings. We refer to this LSTM model as our `prompt_encoder`. LSTM parameters are randomly initialized at the start of p-tuning. All GPT model parameters are frozen, and only the LSTM weights are updated at each training step. LSTM parameters are shared between all tasks that are p-tuned at the same time, but the LSTM model outputs unique virtual token embeddings for each task. The virtual tokens predicted by the LSTM are inserted among the discrete token input in the exact same manner as with prompt-tuning. You still specify the number of virtual tokens you want to use by setting `total_virtual_tokens` and each virtual token embedding is still a 1D vector of size `hidden_size`.\n", - "\n", - "\n", - "\n", - "# The Best of Both\n", - "A single pretrained GPT model can use both p-tuning and prompt-tuning. While you must decide to use either p-tuning or prompt-tuning for each task you want your model to perform, you can p-tune your model on a set of tasks A, then prompt tune your same model on a different set of tasks B, then finally run inference on tasks from both A and B at the same time. During prompt-tuning or p-tuning, tasks tuned at the same time must use the same number of virtual tokens. During inference, tasks using differing amounts of virtual tokens can be run at the same time.\n", - "\n", - "Please see our [docs for more comparisons between prompt and p-tuning](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/nemo_megatron/prompt_learning.html). \n", - "\n", - "With all that covered, let's get started!\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31c27562", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import wget" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "0bfc7709", - "metadata": {}, - "source": [ - "# Tasks and Datasets\n", - "We will be using p-tuning to teach our GPT model to do **Question Answering**.\n", - "\n", - "We will be using the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) reading comprehension dataset, consisting of questions posed by crowd workers on a set of Wikipedia articles, where the answer to every question is a segment of text. More information on [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) can be found on their website or in their paper by Rajpurkar et. al \"[Know What You Don’t Know: Unanswerable Questions for SQuAD](https://arxiv.org/pdf/1806.03822.pdf)\"." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "e0b0072a", - "metadata": {}, - "source": [ - "# Data Preparation\n", - "\n", - "The prompt learning dataset loader accepts a list of json/dictionary objects or a list of json file names where each json file contains a collection of json objects. Each json object must include the field `taskname` which is a string identifier for the task the data example corresponds to. They should also include one or more fields corresponding to different sections of the discrete text prompt. The input data might look like:\n", - "\n", - "```\n", - "[\n", - " {\"taskname\": \"squad\", \"context\": [CONTEXT_PARAGRAPH_TEXT1], \"question\": [QUESTION_TEXT1], \"answer\": [ANSWER_TEXT1]},\n", - " {\"taskname\": \"squad\", \"context\": [CONTEXT_PARAGRAPH_TEXT2], \"question\": [QUESTION_TEXT2], \"answer\": [ANSWER_TEXT2]},\n", - "]\n", - "```\n", - "\n", - "These additional fields can be unlimited in number and will be used to help map different parts of the discrete text input to a prompt template that you define. We will show how this mapping works and how to construct your prompt template in the `Prompt Formatting` section. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0dbd41fd", - "metadata": {}, - "outputs": [], - "source": [ - "# You can replace DATA_DIR and NEMO_DIR with your own locations\n", - "DATA_DIR = \"data\"\n", - "NEMO_DIR = '.'\n", - "\n", - "os.makedirs(DATA_DIR, exist_ok=True)" - ] - }, - { - "cell_type": "markdown", - "id": "504a7b40", - "metadata": {}, - "source": [ - "\n", - "For each dataset we have preprocessing scripts pre-written in NeMo's example directory located in `examples/nlp`. Let's download those now. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e72a1dc1", - "metadata": {}, - "outputs": [], - "source": [ - "# download the preprocessing scripts from github for the purpose of this tutorial\n", - "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/scripts/dataset_processing/nlp/squad/prompt_learning_squad_preprocessing.py', NEMO_DIR)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "71813919", - "metadata": {}, - "source": [ - "Now let's down load and process the dataset." - ] - }, - { - "cell_type": "markdown", - "id": "816791de", - "metadata": {}, - "source": [ - "### SQuAD Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fa16d8ac", - "metadata": {}, - "outputs": [], - "source": [ - "SQUAD_DIR = os.path.join(DATA_DIR, \"SQuAD\")\n", - "os.makedirs(SQUAD_DIR, exist_ok=True)\n", - "\n", - "# Download the SQuAD dataset\n", - "!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json\n", - "!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json\n", - "!mv train-v1.1.json {SQUAD_DIR}\n", - "!mv dev-v1.1.json {SQUAD_DIR}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "64e3e25b", - "metadata": {}, - "outputs": [], - "source": [ - "# Preprocess squad data\n", - "!python $NEMO_DIR/prompt_learning_squad_preprocessing.py --data-dir {SQUAD_DIR}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b562d1de", - "metadata": {}, - "outputs": [], - "source": [ - "# What the squad dataset looks like after processing\n", - "!head -4 $SQUAD_DIR/squad_train.jsonl" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "a385d319", - "metadata": {}, - "source": [ - "We made a `.jsonl` file for each of the train, validation, and testing splits of the squad data. Every `.jsonl` file contains json objects with the fields `taskname`, `context`, `question`, and `answer`. The preprocessing script is called `prompt_learning_squad_preprocessing.py`. It should be in your `NEMO_DIR` and at `scripts/dataset_processing/nlp/squad/prompt_learning_squad_preprocessing.py` in the NeMo repo. \n", - "\n", - "The SQuAD dataset consists of various topics like `Beyoncé`, `IPod`, and `Symbiosis`. Each topic has several paragraphs associated with it, and each paragraph has several questions and answers related to it. When we separated the train/validation/test splits, we separated them on the topic level. For example, if the training set contains paragraphs and questions about the topic `Beyoncé`, neither the validation nor test sets will contain any questions on this topic. All questions about a certain topic are isolated to one split of the data. \n", - "\n", - "Like the Financial PhraseBank Dataset, we randomly selected 80% of the questions for training, 10% for validation, and 10% for test. This resulted in `69125` test examples, `8952` validation examples, and `8744` testing examples. The `answer` field was removed from test examples.\n", - "\n", - "Training on the full train split could take a lot of time, so we are going to clip the train split to 2k examples for the sake of this tutorial, and limit the validation dataset to 200 samples." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f1473ba", - "metadata": {}, - "outputs": [], - "source": [ - "! head -2000 $SQUAD_DIR/squad_train.jsonl > $SQUAD_DIR/squad_short_train.jsonl\n", - "! head -200 $SQUAD_DIR/squad_val.jsonl > $SQUAD_DIR/squad_short_val.jsonl\n" - ] - }, - { - "cell_type": "markdown", - "id": "2e19c8dc", - "metadata": {}, - "source": [ - "# P-Tuning Model Config Setup\n", - "\n", - "Now we will begin setting up the config file used for prompt/p-tuning our GPT models! GPT Prompt learning within NeMo uses a class called `MegatronGPTPromptLearningModel` which has its own config file. We will start by loading an example prompt learning config file, then make changes to it to fit our tasks and training plans. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5749c387", - "metadata": {}, - "outputs": [], - "source": [ - "from omegaconf import OmegaConf\n", - "\n", - "CONFIG_DIR = os.path.join(NEMO_DIR, \"conf\")\n", - "os.makedirs(CONFIG_DIR, exist_ok=True)\n", - "\n", - "# Download the example config file\n", - "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/conf/megatron_gpt_prompt_learning_config.yaml', CONFIG_DIR)\n", - "\n", - "# Load the example config file so we can start editing it\n", - "CONFIG_PATH = os.path.join(CONFIG_DIR, \"megatron_gpt_prompt_learning_config.yaml\")\n", - "config = OmegaConf.load(CONFIG_PATH)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "ce966bcf", - "metadata": {}, - "source": [ - "First let's set the datasets we've created in the config. We are going to start by p-tuning a GPT model on a small subset of the **Squad** task. We do this by setting the following config params below: " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6bb1590f", - "metadata": {}, - "outputs": [], - "source": [ - "config.model.data.train_ds = [f\"{SQUAD_DIR}/squad_short_train.jsonl\"]\n", - "config.model.data.validation_ds = [f\"{SQUAD_DIR}/squad_short_val.jsonl\"]" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "4e021b24", - "metadata": {}, - "source": [ - "### Prompt Formatting\n", - "Now that we have our dataset, lets define what we want the prompt to look like. \n", - "\n", - "The squad dataset json files contain fields named \"context\", \"question\" and \"answer\". The prompt formatting template allows us to arrange these fields and decide where to insert virtual prompts. We can add the `<|VIRTUAL_PROMPT_0|>` token anywere between the fields (although we recommend simply adding it in the leftmost position will be sufficient).\n", - "\n", - "For example, given a data jsonl file with examples like this: \n", - "\n", - "\n", - "**{\"taskname\": \"squad\", \"context\": \"Super Bowl 50 was an American football ga... numerals 50.\", \"question\": \"What does AFC stand for?\", \"answer\": \"American Football Conference\"}**. \n", - "\n", - "\n", - "We can create a prompt template set to `prompt_template = \"<|VIRTUAL_PROMPT_0|> Context: {context}\\n\\nquestion: {question}\\n\\nanswer: {answer}\"` other options are also possible, for example the `\\n` can be replaced with whitespace or the other of the context and question can be swapped. The answer however, should be at the end.\n", - "\n", - "Let's configure the prompt template for the task below:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f935b411", - "metadata": {}, - "outputs": [], - "source": [ - "config.model.task_templates = [\n", - " \n", - " {\n", - " \"taskname\": \"squad\",\n", - " \"prompt_template\": \"<|VIRTUAL_PROMPT_0|> Context: {context}\\n\\nQuestion: {question}\\n\\nAnswer:{answer}\",\n", - " \"total_virtual_tokens\": 15,\n", - " \"virtual_token_splits\": [15],\n", - " \"truncate_field\": \"context\",\n", - " \"answer_only_loss\": True,\n", - " \"answer_field\": \"answer\",\n", - " },\n", - " \n", - "]" - ] - }, - { - "cell_type": "markdown", - "id": "dcc438b5", - "metadata": {}, - "source": [ - "Note each `task_template` item has 5 fields. \n", - "\n", - "- **`prompt_template`** is a string showing the model where to place virtual tokens and how to map dataset json fields to where they belong in the model prompt. \n", - "\n", - "\n", - "- **`taskname`** refers to the same `taskname` in the dataset json objects. \n", - "\n", - "\n", - "- **`total_virtual_tokens`** specifies the total number of virtual tokens that will be inserted into the model prompt.\n", - "\n", - "\n", - "- **`virtual_token_splits`** specifies the number of virtual tokens that belong at each `<|VIRTUAL_PROMPT_#|>` marker. `virtual_token_splits` values should add up to `total_virtual_tokens`. The number of `virtual_token_splits` should match the number of `<|VIRTUAL_PROMPT_#|>` markers. \n", - "\n", - "\n", - "- **`truncate_field`** specifies which field in the data json to truncate if the length of the input exceeds the maximum sequence length of the model. If `truncate_field` is set to `None`, examples that are too long are simply dropped from the dataset.\n", - "\n", - "\n", - "- **`answer_only_loss`** Whether to limit loss calculation to only the answer portion of the prompt during tuning. `True` Strongly recommended for long prompts, but shorter prompts with single word answers seem to benefit from setting this to `False`. \n", - "\n", - "\n", - "- **`answer_field`** The field in the data json corresponding to the answer. The loss will only be calculated on this portion of the prompt if `answer_only_loss` is `True`. The answer field must be at the end of the prompt template.\n", - "\n", - "In the `task_templates` we set above, `squad` has a different number of virtual tokens than `sentiment` and `intent_and_slot`. This is because we will be p-tuning on `squad` after we p-tune on the other two tasks and **we do not need to use the same number of virtual tokens between sessions**. We also set the `truncate` field for squad because the context can sometimes be longer than the model's max sequence length, and we want that field to be truncated if the example is too long. Lastly, we set `answer_only_loss` to true for `squad` due to the longer prompt. We've found `answer_only_loss=True` to work significantly better for this task." - ] - }, - { - "cell_type": "markdown", - "id": "84579c7a", - "metadata": {}, - "source": [ - "### Setting New Tasks\n", - "After you p-tune your model this time, you can always go back and p-tune or prompt-tune your model on more tasks without over writing the virtual prompts who've trained this time. You can also use a different number of `total_virtual_tokens` between each training session as long as tasks p-tuned or prompt tuned at the same time have the same number of `total_virtual_tokens`. For this reason, when you p-tune on a new task, you need to tell your model which of your tasks are new and which ones already exist (and thus you don't want to tune them). \n", - "\n", - "You do this by setting the `new_tasks` and `existing_tasks` values in the config file. Because we are p-tuning a model with no existing tasks, you should set `existing_tasks=[]` and `new_tasks=[\"sentiment\", \"intent_and_slot\"]` as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "57a73e01", - "metadata": {}, - "outputs": [], - "source": [ - "config.model.existing_tasks = []\n", - "config.model.new_tasks = [\"squad\"]" - ] - }, - { - "cell_type": "markdown", - "id": "3b77e88c", - "metadata": {}, - "source": [ - "After p-tuning and/or prompt tuning is complete, you can run inference on all tasks at the same time, regardless of their `total_virtual_tokens` value." - ] - }, - { - "cell_type": "markdown", - "id": "a0d5017e", - "metadata": {}, - "source": [ - "### Setting The Pre-Trained GPT Model\n", - "We still need to set which GPT model we want to p-tune/prompt tune. Prompt learning methods work best with large GPT language models (5B or above), but the purposes of this tutorial, we are going to download a 345M parameter GPT model from NVIDIA NGC." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "48cdf868", - "metadata": {}, - "outputs": [], - "source": [ - "# Check what GPT .nemo models we have available on NGC\n", - "from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel\n", - "MegatronGPTModel.list_available_models()" - ] - }, - { - "cell_type": "markdown", - "id": "ede350ed", - "metadata": {}, - "source": [ - "If we wanted to use the GPT model class directly, we could instantiate a trainer then download the model by calling running \n", - "`gpt_model = MegatronGPTModel.from_pretrained(model_name=\"megatron_gpt_345m\", trainer=trainer).cuda()`. But we just need the `.nemo` file in our working NeMo directory in this tutorial, so we will download it using `wget`. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "364439a1", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# Download the model from NGC\n", - "gpt_file_name = \"megatron_gpt_345m.nemo\"\n", - "!wget -nc --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/nemo/megatron_gpt_345m/versions/1/files/megatron_gpt_345m.nemo -O {NEMO_DIR}/{gpt_file_name}" - ] - }, - { - "cell_type": "markdown", - "id": "1d6a8a67", - "metadata": {}, - "source": [ - "Now that we have a `.nemo` GPT file to work with. We need to add its path in our prompt learning config. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2778a5fa", - "metadata": {}, - "outputs": [], - "source": [ - "# Set GPT model path on prompt learning config\n", - "config.model.language_model_path = gpt_file_name" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "943a9c83", - "metadata": {}, - "source": [ - "We can also set where we want the final prompt tuned model to be saved by setting `model.nemo_path`. By default the tuned prompt learning model will be saved in your current working directory to a `.nemo` file with the same name as your experiment (`config.name`). Let's change the save name to be `p_tuned_gpt.nemo`. **Your model path must end in `.nemo`.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a278cbdf", - "metadata": {}, - "outputs": [], - "source": [ - "config.exp_manager.checkpoint_callback_params.save_nemo_on_train_end= True\n", - "config.exp_manager.checkpoint_callback_params.always_save_nemo= True\n", - "config.exp_manager.checkpoint_callback_params.save_best_model= True" - ] - }, - { - "cell_type": "markdown", - "id": "378a73e7", - "metadata": {}, - "source": [ - "### Setting P-Tuning Specific Params\n", - "Within the config file, p-tuning and prompt-tuning each have a couple of hyperparameters specific to them. We first need to tell the model that we want to do p-tuning, not prompt-tuning. To do this, we set the **`model.virtual_prompt_style`** hyperparameter like this:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68763763", - "metadata": {}, - "outputs": [], - "source": [ - "from nemo.collections.nlp.modules.common import VirtualPromptStyle\n", - "config.model.virtual_prompt_style = VirtualPromptStyle.P_TUNING" - ] - }, - { - "cell_type": "markdown", - "id": "947dec63", - "metadata": {}, - "source": [ - "Then we can set the 2 p-tuning specific parameters. Reminder, p-tuning uses an LSTM prompt encoder to predict virtual tokens. \n", - "\n", - "- **`p_tuning.dropout`** the LSTM prompt encoder dropout probability \n", - "- **`p_tuning.num_layers`** the number of LSTM layers you want your p-tuning prompt encoder to have\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "03f893ef", - "metadata": {}, - "outputs": [], - "source": [ - "config.model.p_tuning.dropout = 0.0\n", - "config.model.p_tuning.num_layers = 2\n", - "config.model.global_batch_size = 2\n", - "config.model.micro_batch_size = 1" - ] - }, - { - "cell_type": "markdown", - "id": "a988d16e", - "metadata": {}, - "source": [ - "Let's have a look at all the values we've set in the model config. You can change any of these values in the same manner we've been using above. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12a37ada", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# Final model config\n", - "print(OmegaConf.to_yaml(config.model))" - ] - }, - { - "cell_type": "markdown", - "id": "6b4bc7f3", - "metadata": {}, - "source": [ - "### Setting Prompt-Tuning Specific Params\n", - "\n", - "Though we are not using prompt tuning in this training session, let's go over the prompt tuning specific parameters we would use if we were. \n", - "\n", - "- **`prompt_tuning.new_prompt_init_methods`** Whether you want to initialize virtual token embeddings from the embeddings of existing parts of the model's vocabulary (either 'text' or 'random')\n", - "- **`prompt_tuning.new_prompt_init_text`** The text you want to use if you have 'text' in the list above, should be None otherwise. \n", - "\n", - "Each of the above hyperparameters are a list of strings. \n", - "\n", - "`new_prompt_init_methods` would look like `[\"text\", \"random\", \"text\", \"text\"]` if you were prompt tuning on 4 tasks at once, and you wanted the second task in `new_tasks` to use random initialization. \n", - "\n", - "`new_prompt_init_text` might look like `[\"some text I want to use\", None, \"some other text\", \"task text goes here\"]` for those four new tasks. \n", - "\n", - "The order of both should correspond to the order of the tasks you have listed in `model.new_tasks`. " - ] - }, - { - "cell_type": "markdown", - "id": "4c048852", - "metadata": {}, - "source": [ - "# Building the PyTorch Lightning Trainer\n", - "NeMo models are primarily PyTorch Lightning modules - and therefore are entirely compatible with the PyTorch Lightning ecosystem.\n", - "\n", - "Let's first instantiate a Trainer object" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "90f85b2a", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import pytorch_lightning as pl\n", - "from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy\n", - "from pytorch_lightning.plugins.environments import TorchElasticEnvironment\n", - "\n", - "# let's modify some trainer configs\n", - "# check if we have GPU available and uses it\n", - "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", - "config.trainer.accelerator = accelerator\n", - "config.trainer.devices = 1\n", - "config.trainer.max_epochs = 4\n", - "config.trainer.val_check_interval = 1.0\n", - "\n", - "# for PyTorch Native AMP set precision=16\n", - "config.trainer.precision = 16 if torch.cuda.is_available() else 32\n", - "\n", - "# setup cluster environment parameters\"\n", - "# use torch elastic cluster environment so `create_process_externally` is True\n", - "# the launcher is set to None. It will not try to spawn new processes.\n", - "# It won't create the misconfiguration error because of the `interactive session`\n", - "os.environ[\"LOCAL_RANK\"] = '0'\n", - "os.environ[\"RANK\"] = '0'\n", - "os.environ[\"WORLD_SIZE\"] = '1'\n", - "\n", - "strategy = NLPDDPStrategy(find_unused_parameters=False, no_ddp_communication_hook=True)\n", - "plugins = [TorchElasticEnvironment()]\n", - "trainer = pl.Trainer(plugins= plugins, strategy=strategy, **config.trainer)\n", - "\n", - "print(\"Trainer config - \\n\")\n", - "print(OmegaConf.to_yaml(config.trainer))" - ] - }, - { - "cell_type": "markdown", - "id": "4d0124c1", - "metadata": {}, - "source": [ - "# Setting up a NeMo Experiment\n", - "\n", - "NeMo has an experiment manager that handles logging and checkpointing for us, so let's use it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f2c943ba", - "metadata": {}, - "outputs": [], - "source": [ - "from nemo.utils.exp_manager import exp_manager\n", - "\n", - "# Set name of the experiment \n", - "config.name = 'p_tuning'\n", - "config.exp_manager.resume_if_exists = False\n", - "\n", - "# Init the experiment manager and view the exp_dir\n", - "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n", - "exp_dir = str(exp_dir)\n", - "print(exp_dir)" - ] - }, - { - "cell_type": "markdown", - "id": "5860bd90", - "metadata": {}, - "source": [ - "We can also set learning hyperparameters as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4c4ec542", - "metadata": {}, - "outputs": [], - "source": [ - "# Set some of the learning parameters\n", - "config.model.optim.lr = 1e-4\n", - "config.model.precision = config.trainer.precision" - ] - }, - { - "cell_type": "markdown", - "id": "298b3dce", - "metadata": {}, - "source": [ - "# First P-Tuning Session\n", - "The only thing left to do is load up the model and begin p-tuning!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b4bda19b", - "metadata": {}, - "outputs": [], - "source": [ - "from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import MegatronGPTPromptLearningModel\n", - "\n", - "model = MegatronGPTPromptLearningModel(cfg=config.model, trainer=trainer)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2d99f433", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# Training set to 2 epochs by default in a cell above\n", - "# Each epoch will take around 1min 15sec, but training time can vary\n", - "trainer.fit(model)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "6aab09d4", - "metadata": {}, - "source": [ - "# Inference After P-Tuning\n", - "One way to run inference after p-tuning or prompt-tuning your model is to call `model.generate()`. `model.generate()` takes in \n", - "\n", - "- `inputs` which can be either a list of dictionary objects or `.jsonl` files containing dictionary objects, \n", - "- `length_params`\n", - "- `sampling_params`\n", - "\n", - "as arguments. More information about the [text generation API can be found here](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/transformer/text_generation.py).\n", - "\n", - "If `length_params` and `sampling_params` are set to `None`, the model generates output with a greedy decoding strategy and generates up to `30` new tokens. Most predictive downstream tasks (not text generation tasks), use greedy sampling. To see other ways to run inference with your prompt learning model and more details on how to define various inference parameters, visit `examples/nlp/language_modeling/megatron_gpt_eval.py`.\n", - "\n", - "Below are some randomly selected test examples from the sentiment classification and intent and slot classification test files. Notice that the `label` field is dropped from all test examples. The `MegatronPromptLearningDataset` called within `.generate()` automatically leaves fields in the prompt template empty when they are not provided in the data json. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dc95e764", - "metadata": {}, - "outputs": [], - "source": [ - "test_examples = [\n", - " {\"taskname\": \"squad\", \"context\": \"The build was released for download later in the day in standard 32-bit and 64-bit versions, plus a special 64-bit version which included SDKs and developer tools (Visual Studio Express and Expression Blend) for developing Metro-style apps. The Windows Store was announced during the presentation, but was not available in this build. According to Microsoft, there were about 535,000 downloads of the developer preview within the first 12 hours of its release. Originally set to expire on March 11, 2012, in February 2012 the Developer Preview's expiry date was changed to January 15, 2013.\", \"question\": \"When was the Developer preview initially intended to expire?\"},\n", - " {\"taskname\": \"squad\", \"context\": \"The structures of most federal governments incorporate mechanisms to protect the rights of component states. One method, known as 'intrastate federalism', is to directly represent the governments of component states in federal political institutions. Where a federation has a bicameral legislature the upper house is often used to represent the component states while the lower house represents the people of the nation as a whole. A federal upper house may be based on a special scheme of apportionment, as is the case in the senates of the United States and Australia, where each state is represented by an equal number of senators irrespective of the size of its population.\", \"question\": \"What is a bicameral legislature?\"},\n", - " {\"taskname\": \"squad\", \"context\": \"Imported mystery religions, which offered initiates salvation in the afterlife, were a matter of personal choice for an individual, practiced in addition to carrying on one's family rites and participating in public religion. The mysteries, however, involved exclusive oaths and secrecy, conditions that conservative Romans viewed with suspicion as characteristic of \\\"magic\\\", conspiratorial (coniuratio), or subversive activity. Sporadic and sometimes brutal attempts were made to suppress religionists who seemed to threaten traditional morality and unity, as with the senate's efforts to restrict the Bacchanals in 186 BC.\", \"question\": \"What was the practice of religion to the Romans?\"}\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "74a5a358", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "response = model.generate(inputs=test_examples, length_params=None)\n", - "\n", - "print('The prediction results of some sample queries with the trained model:')\n", - "for result in response['sentences']:\n", - " print(result)\n", - " print(\"-\" * 30)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.16" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b7a434f4", + "metadata": {}, + "outputs": [], + "source": [ + "BRANCH='main'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "developmental-gibraltar", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run this cell to set up dependencies.\n", + "\"\"\"\n", + "# If you're using Google Colab and not running locally, run this cell\n", + "\n", + "# install NeMo\n", + "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[nlp]" + ] + }, + { + "cell_type": "markdown", + "id": "42daf8bf", + "metadata": {}, + "source": [ + "# Introduction\n", + "\n", + "In this notebook we demonstrate how to use p-tuning and prompt tuning within NeMo-Megatron. Both methods are parameter efficient alternatives to fine-tuning pretrained language models. Our NeMo implementation makes it possible to use one pretrained GPT model on many downstream tasks without needing to tune the model’s full set of parameters. It also allows for adding new tasks to your model without overwriting or disrupting previous tasks for which the model has already been p-tuned/prompt-tuned. Because the original model parameters are frozen and never altered by either method, p-tuning/prompt-tuning also avoid catastrophic forgetting issues often encountered when fine-tuning models.\n", + "\n", + "- Our prompt tuning implementation is based off Lester et. al’s EMNLP 2021 paper [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/abs/2104.08691)\n", + "\n", + "- Our p-tuning implementation is based off Liu et al's paper [GPT Understands, Too](https://arxiv.org/abs/2103.10385).\n", + "\n", + "- Command line usage examples and API documentation can be found in [our user docs](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/nemo_megatron/prompt_learning.html). \n", + "\n", + "\"Prompt\n", + "\n", + "Our continuous learning capability for combined p-tuning and prompt tuning with GPT style models is a NeMo specific extension of the author’s original work.\n", + "\n", + "# The Plan\n", + "\n", + "We are going to show you how to:\n", + " \n", + " 1. P-Tune/Prompt Tune a model on multiple tasks at the same time\n", + " 2. Add a new task to a model that has already been P-Tuned/Prompt Tuned previously\n", + " \n", + "We will first p-tune a GPT model on sentiment analysis, and intent and slot classification tasks. Then we will show how to add the squad question answering task to the same model we already p-tuned once.\n", + "\n", + "\n", + "# Technical Overview\n", + "Instead of selecting discrete text prompts in a manual or automated fashion, prompt tuning and p-tuning utilize virtual prompt embeddings that can be optimized via gradient decent. The only difference between prompt tuning and p-tuning within NeMo-Megatron is the architecture used to tune the soft prompt tokens during training.\n", + "\n", + "### Terminology\n", + "We will be using the terms `continuous`, `soft`, and `virtual` token interchangeably to refer to embeddings inserted into the model prompt that have no concrete mapping to strings or characters within the model’s vocabulary. These virtual token embeddings exist in contrast to the `discrete`, `hard`, or `real` tokens that do make up the model’s vocabulary. Virtual tokens are purely 1D vectors with dimensionality equal to that of each real token embedding, matching the `hidden_size` hyperparameter. In training and inference, continuous token embeddings are inserted among discrete token embeddings according to a template you provide in the model’s config. We will demonstrate how to do this below.\n", + "\n", + "When referring to p-tuning and prompt tuning together, we will be using the phrase prompt learning for simplicity.\n", + "\n", + "### Prompt-Tuning\n", + "In prompt-tuning a pretrained GPT model, soft prompt embeddings are initialized as a 2D matrix of size `total_virtual_tokens X hidden_size`. Each task the model is prompt-tuned to perform has its own 2D embedding matrix associated with it. Tasks do not share any parameters during training or inference. All GPT model parameters are frozen and only the embedding parameters for each task are updated during training.\n", + "\n", + "In prompt tuning you can specify how the embeddings are initialized for each task. You can either\n", + "\n", + "1. Initialize embedding parameters according to some random distribution\n", + "2. Initialize embedding parameters from existing vocabulary embeddings (recommended)\n", + "\n", + "If you choose to initialize virtual token embeddings from existing embedding weights, you can provide the string of words you want to use for initialization in the model’s config. This string will be tokenized and tiled or truncated to match the specified number of virtual tokens you would like to use (`total_virtual_tokens`). Vocab embeddings are copied and used to initialize the soft prompt embedding matrix for each task. The vocab embeddings themselves are not updated or changed during prompt tuning.\n", + "\n", + "\n", + "### P-Tuning\n", + "In p-tuning, an LSTM model is used to predict virtual token embeddings. We refer to this LSTM model as our `prompt_encoder`. LSTM parameters are randomly initialized at the start of p-tuning. All GPT model parameters are frozen, and only the LSTM weights are updated at each training step. LSTM parameters are shared between all tasks that are p-tuned at the same time, but the LSTM model outputs unique virtual token embeddings for each task. The virtual tokens predicted by the LSTM are inserted among the discrete token input in the exact same manner as with prompt-tuning. You still specify the number of virtual tokens you want to use by setting `total_virtual_tokens` and each virtual token embedding is still a 1D vector of size `hidden_size`.\n", + "\n", + "\n", + "\n", + "# The Best of Both\n", + "A single pretrained GPT model can use both p-tuning and prompt-tuning. While you must decide to use either p-tuning or prompt-tuning for each task you want your model to perform, you can p-tune your model on a set of tasks A, then prompt tune your same model on a different set of tasks B, then finally run inference on tasks from both A and B at the same time. During prompt-tuning or p-tuning, tasks tuned at the same time must use the same number of virtual tokens. During inference, tasks using differing amounts of virtual tokens can be run at the same time.\n", + "\n", + "Please see our [docs for more comparisons between prompt and p-tuning](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/nemo_megatron/prompt_learning.html). \n", + "\n", + "With all that covered, let's get started!\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31c27562", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import wget" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0bfc7709", + "metadata": {}, + "source": [ + "# Tasks and Datasets\n", + "We will be using p-tuning to teach our GPT model to do **Question Answering**.\n", + "\n", + "We will be using the [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) reading comprehension dataset, consisting of questions posed by crowd workers on a set of Wikipedia articles, where the answer to every question is a segment of text. More information on [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) can be found on their website or in their paper by Rajpurkar et. al \"[Know What You Don’t Know: Unanswerable Questions for SQuAD](https://arxiv.org/pdf/1806.03822.pdf)\"." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "e0b0072a", + "metadata": {}, + "source": [ + "# Data Preparation\n", + "\n", + "The prompt learning dataset loader accepts a list of json/dictionary objects or a list of json file names where each json file contains a collection of json objects. Each json object must include the field `taskname` which is a string identifier for the task the data example corresponds to. They should also include one or more fields corresponding to different sections of the discrete text prompt. The input data might look like:\n", + "\n", + "```\n", + "[\n", + " {\"taskname\": \"squad\", \"context\": [CONTEXT_PARAGRAPH_TEXT1], \"question\": [QUESTION_TEXT1], \"answer\": [ANSWER_TEXT1]},\n", + " {\"taskname\": \"squad\", \"context\": [CONTEXT_PARAGRAPH_TEXT2], \"question\": [QUESTION_TEXT2], \"answer\": [ANSWER_TEXT2]},\n", + "]\n", + "```\n", + "\n", + "These additional fields can be unlimited in number and will be used to help map different parts of the discrete text input to a prompt template that you define. We will show how this mapping works and how to construct your prompt template in the `Prompt Formatting` section. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dbd41fd", + "metadata": {}, + "outputs": [], + "source": [ + "# You can replace DATA_DIR and NEMO_DIR with your own locations\n", + "DATA_DIR = \"data\"\n", + "NEMO_DIR = '.'\n", + "\n", + "os.makedirs(DATA_DIR, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "504a7b40", + "metadata": {}, + "source": [ + "\n", + "For each dataset we have preprocessing scripts pre-written in NeMo's example directory located in `examples/nlp`. Let's download those now. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e72a1dc1", + "metadata": {}, + "outputs": [], + "source": [ + "# download the preprocessing scripts from github for the purpose of this tutorial\n", + "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/scripts/dataset_processing/nlp/squad/prompt_learning_squad_preprocessing.py', NEMO_DIR)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "71813919", + "metadata": {}, + "source": [ + "Now let's down load and process the dataset." + ] + }, + { + "cell_type": "markdown", + "id": "816791de", + "metadata": {}, + "source": [ + "### SQuAD Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa16d8ac", + "metadata": {}, + "outputs": [], + "source": [ + "SQUAD_DIR = os.path.join(DATA_DIR, \"SQuAD\")\n", + "os.makedirs(SQUAD_DIR, exist_ok=True)\n", + "\n", + "# Download the SQuAD dataset\n", + "!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json\n", + "!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json\n", + "!mv train-v1.1.json {SQUAD_DIR}\n", + "!mv dev-v1.1.json {SQUAD_DIR}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64e3e25b", + "metadata": {}, + "outputs": [], + "source": [ + "# Preprocess squad data\n", + "!python $NEMO_DIR/prompt_learning_squad_preprocessing.py --data-dir {SQUAD_DIR}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b562d1de", + "metadata": {}, + "outputs": [], + "source": [ + "# What the squad dataset looks like after processing\n", + "!head -4 $SQUAD_DIR/squad_train.jsonl" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "a385d319", + "metadata": {}, + "source": [ + "We made a `.jsonl` file for each of the train, validation, and testing splits of the squad data. Every `.jsonl` file contains json objects with the fields `taskname`, `context`, `question`, and `answer`. The preprocessing script is called `prompt_learning_squad_preprocessing.py`. It should be in your `NEMO_DIR` and at `scripts/dataset_processing/nlp/squad/prompt_learning_squad_preprocessing.py` in the NeMo repo. \n", + "\n", + "The SQuAD dataset consists of various topics like `Beyoncé`, `IPod`, and `Symbiosis`. Each topic has several paragraphs associated with it, and each paragraph has several questions and answers related to it. When we separated the train/validation/test splits, we separated them on the topic level. For example, if the training set contains paragraphs and questions about the topic `Beyoncé`, neither the validation nor test sets will contain any questions on this topic. All questions about a certain topic are isolated to one split of the data. \n", + "\n", + "Like the Financial PhraseBank Dataset, we randomly selected 80% of the questions for training, 10% for validation, and 10% for test. This resulted in `69125` test examples, `8952` validation examples, and `8744` testing examples. The `answer` field was removed from test examples.\n", + "\n", + "Training on the full train split could take a lot of time, so we are going to clip the train split to 2k examples for the sake of this tutorial, and limit the validation dataset to 200 samples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f1473ba", + "metadata": {}, + "outputs": [], + "source": [ + "! head -2000 $SQUAD_DIR/squad_train.jsonl > $SQUAD_DIR/squad_short_train.jsonl\n", + "! head -200 $SQUAD_DIR/squad_val.jsonl > $SQUAD_DIR/squad_short_val.jsonl\n" + ] + }, + { + "cell_type": "markdown", + "id": "2e19c8dc", + "metadata": {}, + "source": [ + "# P-Tuning Model Config Setup\n", + "\n", + "Now we will begin setting up the config file used for prompt/p-tuning our GPT models! GPT Prompt learning within NeMo uses a class called `MegatronGPTPromptLearningModel` which has its own config file. We will start by loading an example prompt learning config file, then make changes to it to fit our tasks and training plans. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5749c387", + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf\n", + "\n", + "CONFIG_DIR = os.path.join(NEMO_DIR, \"conf\")\n", + "os.makedirs(CONFIG_DIR, exist_ok=True)\n", + "\n", + "# Download the example config file\n", + "wget.download(f'https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/language_modeling/conf/megatron_gpt_prompt_learning_config.yaml', CONFIG_DIR)\n", + "\n", + "# Load the example config file so we can start editing it\n", + "CONFIG_PATH = os.path.join(CONFIG_DIR, \"megatron_gpt_prompt_learning_config.yaml\")\n", + "config = OmegaConf.load(CONFIG_PATH)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ce966bcf", + "metadata": {}, + "source": [ + "First let's set the datasets we've created in the config. We are going to start by p-tuning a GPT model on a small subset of the **Squad** task. We do this by setting the following config params below: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bb1590f", + "metadata": {}, + "outputs": [], + "source": [ + "config.model.data.train_ds = [f\"{SQUAD_DIR}/squad_short_train.jsonl\"]\n", + "config.model.data.validation_ds = [f\"{SQUAD_DIR}/squad_short_val.jsonl\"]" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "4e021b24", + "metadata": {}, + "source": [ + "### Prompt Formatting\n", + "Now that we have our dataset, lets define what we want the prompt to look like. \n", + "\n", + "The squad dataset json files contain fields named \"context\", \"question\" and \"answer\". The prompt formatting template allows us to arrange these fields and decide where to insert virtual prompts. We can add the `<|VIRTUAL_PROMPT_0|>` token anywere between the fields (although we recommend simply adding it in the leftmost position will be sufficient).\n", + "\n", + "For example, given a data jsonl file with examples like this: \n", + "\n", + "\n", + "**{\"taskname\": \"squad\", \"context\": \"Super Bowl 50 was an American football ga... numerals 50.\", \"question\": \"What does AFC stand for?\", \"answer\": \"American Football Conference\"}**. \n", + "\n", + "\n", + "We can create a prompt template set to `prompt_template = \"<|VIRTUAL_PROMPT_0|> Context: {context}\\n\\nquestion: {question}\\n\\nanswer: {answer}\"` other options are also possible, for example the `\\n` can be replaced with whitespace or the other of the context and question can be swapped. The answer however, should be at the end.\n", + "\n", + "Let's configure the prompt template for the task below:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f935b411", + "metadata": {}, + "outputs": [], + "source": [ + "config.model.task_templates = [\n", + " \n", + " {\n", + " \"taskname\": \"squad\",\n", + " \"prompt_template\": \"<|VIRTUAL_PROMPT_0|> Context: {context}\\n\\nQuestion: {question}\\n\\nAnswer:{answer}\",\n", + " \"total_virtual_tokens\": 15,\n", + " \"virtual_token_splits\": [15],\n", + " \"truncate_field\": \"context\",\n", + " \"answer_only_loss\": True,\n", + " \"answer_field\": \"answer\",\n", + " },\n", + " \n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "dcc438b5", + "metadata": {}, + "source": [ + "Note each `task_template` item has 5 fields. \n", + "\n", + "- **`prompt_template`** is a string showing the model where to place virtual tokens and how to map dataset json fields to where they belong in the model prompt. \n", + "\n", + "\n", + "- **`taskname`** refers to the same `taskname` in the dataset json objects. \n", + "\n", + "\n", + "- **`total_virtual_tokens`** specifies the total number of virtual tokens that will be inserted into the model prompt.\n", + "\n", + "\n", + "- **`virtual_token_splits`** specifies the number of virtual tokens that belong at each `<|VIRTUAL_PROMPT_#|>` marker. `virtual_token_splits` values should add up to `total_virtual_tokens`. The number of `virtual_token_splits` should match the number of `<|VIRTUAL_PROMPT_#|>` markers. \n", + "\n", + "\n", + "- **`truncate_field`** specifies which field in the data json to truncate if the length of the input exceeds the maximum sequence length of the model. If `truncate_field` is set to `None`, examples that are too long are simply dropped from the dataset.\n", + "\n", + "\n", + "- **`answer_only_loss`** Whether to limit loss calculation to only the answer portion of the prompt during tuning. `True` Strongly recommended for long prompts, but shorter prompts with single word answers seem to benefit from setting this to `False`. \n", + "\n", + "\n", + "- **`answer_field`** The field in the data json corresponding to the answer. The loss will only be calculated on this portion of the prompt if `answer_only_loss` is `True`. The answer field must be at the end of the prompt template.\n", + "\n", + "In the `task_templates` we set above, `squad` has a different number of virtual tokens than `sentiment` and `intent_and_slot`. This is because we will be p-tuning on `squad` after we p-tune on the other two tasks and **we do not need to use the same number of virtual tokens between sessions**. We also set the `truncate` field for squad because the context can sometimes be longer than the model's max sequence length, and we want that field to be truncated if the example is too long. Lastly, we set `answer_only_loss` to true for `squad` due to the longer prompt. We've found `answer_only_loss=True` to work significantly better for this task." + ] + }, + { + "cell_type": "markdown", + "id": "84579c7a", + "metadata": {}, + "source": [ + "### Setting New Tasks\n", + "After you p-tune your model this time, you can always go back and p-tune or prompt-tune your model on more tasks without over writing the virtual prompts who've trained this time. You can also use a different number of `total_virtual_tokens` between each training session as long as tasks p-tuned or prompt tuned at the same time have the same number of `total_virtual_tokens`. For this reason, when you p-tune on a new task, you need to tell your model which of your tasks are new and which ones already exist (and thus you don't want to tune them). \n", + "\n", + "You do this by setting the `new_tasks` and `existing_tasks` values in the config file. Because we are p-tuning a model with no existing tasks, you should set `existing_tasks=[]` and `new_tasks=[\"sentiment\", \"intent_and_slot\"]` as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57a73e01", + "metadata": {}, + "outputs": [], + "source": [ + "config.model.existing_tasks = []\n", + "config.model.new_tasks = [\"squad\"]" + ] + }, + { + "cell_type": "markdown", + "id": "3b77e88c", + "metadata": {}, + "source": [ + "After p-tuning and/or prompt tuning is complete, you can run inference on all tasks at the same time, regardless of their `total_virtual_tokens` value." + ] + }, + { + "cell_type": "markdown", + "id": "a0d5017e", + "metadata": {}, + "source": [ + "### Setting The Pre-Trained GPT Model\n", + "We still need to set which GPT model we want to p-tune/prompt tune. Prompt learning methods work best with large GPT language models (5B or above), but the purposes of this tutorial, we are going to download a 345M parameter GPT model from NVIDIA NGC." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48cdf868", + "metadata": {}, + "outputs": [], + "source": [ + "# Check what GPT .nemo models we have available on NGC\n", + "from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel\n", + "MegatronGPTModel.list_available_models()" + ] + }, + { + "cell_type": "markdown", + "id": "ede350ed", + "metadata": {}, + "source": [ + "If we wanted to use the GPT model class directly, we could instantiate a trainer then download the model by calling running \n", + "`gpt_model = MegatronGPTModel.from_pretrained(model_name=\"megatron_gpt_345m\", trainer=trainer).cuda()`. But we just need the `.nemo` file in our working NeMo directory in this tutorial, so we will download it using `wget`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "364439a1", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Download the model from NGC\n", + "gpt_file_name = \"megatron_gpt_345m.nemo\"\n", + "!wget -nc --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/nemo/megatron_gpt_345m/versions/1/files/megatron_gpt_345m.nemo -O {NEMO_DIR}/{gpt_file_name}" + ] + }, + { + "cell_type": "markdown", + "id": "1d6a8a67", + "metadata": {}, + "source": [ + "Now that we have a `.nemo` GPT file to work with. We need to add its path in our prompt learning config. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2778a5fa", + "metadata": {}, + "outputs": [], + "source": [ + "# Set GPT model path on prompt learning config\n", + "config.model.language_model_path = gpt_file_name" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "943a9c83", + "metadata": {}, + "source": [ + "We can also set where we want the final prompt tuned model to be saved by setting `model.nemo_path`. By default the tuned prompt learning model will be saved in your current working directory to a `.nemo` file with the same name as your experiment (`config.name`). Let's change the save name to be `p_tuned_gpt.nemo`. **Your model path must end in `.nemo`.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a278cbdf", + "metadata": {}, + "outputs": [], + "source": [ + "config.exp_manager.checkpoint_callback_params.save_nemo_on_train_end= True\n", + "config.exp_manager.checkpoint_callback_params.always_save_nemo= True\n", + "config.exp_manager.checkpoint_callback_params.save_best_model= True" + ] + }, + { + "cell_type": "markdown", + "id": "378a73e7", + "metadata": {}, + "source": [ + "### Setting P-Tuning Specific Params\n", + "Within the config file, p-tuning and prompt-tuning each have a couple of hyperparameters specific to them. We first need to tell the model that we want to do p-tuning, not prompt-tuning. To do this, we set the **`model.virtual_prompt_style`** hyperparameter like this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68763763", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.nlp.modules.common import VirtualPromptStyle\n", + "config.model.virtual_prompt_style = VirtualPromptStyle.P_TUNING" + ] + }, + { + "cell_type": "markdown", + "id": "947dec63", + "metadata": {}, + "source": [ + "Then we can set the 2 p-tuning specific parameters. Reminder, p-tuning uses an LSTM prompt encoder to predict virtual tokens. \n", + "\n", + "- **`p_tuning.dropout`** the LSTM prompt encoder dropout probability \n", + "- **`p_tuning.num_layers`** the number of LSTM layers you want your p-tuning prompt encoder to have\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03f893ef", + "metadata": {}, + "outputs": [], + "source": [ + "config.model.p_tuning.dropout = 0.0\n", + "config.model.p_tuning.num_layers = 2\n", + "config.model.global_batch_size = 2\n", + "config.model.micro_batch_size = 1" + ] + }, + { + "cell_type": "markdown", + "id": "a988d16e", + "metadata": {}, + "source": [ + "Let's have a look at all the values we've set in the model config. You can change any of these values in the same manner we've been using above. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12a37ada", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Final model config\n", + "print(OmegaConf.to_yaml(config.model))" + ] + }, + { + "cell_type": "markdown", + "id": "6b4bc7f3", + "metadata": {}, + "source": [ + "### Setting Prompt-Tuning Specific Params\n", + "\n", + "Though we are not using prompt tuning in this training session, let's go over the prompt tuning specific parameters we would use if we were. \n", + "\n", + "- **`prompt_tuning.new_prompt_init_methods`** Whether you want to initialize virtual token embeddings from the embeddings of existing parts of the model's vocabulary (either 'text' or 'random')\n", + "- **`prompt_tuning.new_prompt_init_text`** The text you want to use if you have 'text' in the list above, should be None otherwise. \n", + "\n", + "Each of the above hyperparameters are a list of strings. \n", + "\n", + "`new_prompt_init_methods` would look like `[\"text\", \"random\", \"text\", \"text\"]` if you were prompt tuning on 4 tasks at once, and you wanted the second task in `new_tasks` to use random initialization. \n", + "\n", + "`new_prompt_init_text` might look like `[\"some text I want to use\", None, \"some other text\", \"task text goes here\"]` for those four new tasks. \n", + "\n", + "The order of both should correspond to the order of the tasks you have listed in `model.new_tasks`. " + ] + }, + { + "cell_type": "markdown", + "id": "4c048852", + "metadata": {}, + "source": [ + "# Building the PyTorch Lightning Trainer\n", + "NeMo models are primarily PyTorch Lightning modules - and therefore are entirely compatible with the PyTorch Lightning ecosystem.\n", + "\n", + "Let's first instantiate a Trainer object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90f85b2a", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import pytorch_lightning as pl\n", + "from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy\n", + "from pytorch_lightning.plugins.environments import TorchElasticEnvironment\n", + "\n", + "# let's modify some trainer configs\n", + "# check if we have GPU available and uses it\n", + "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", + "config.trainer.accelerator = accelerator\n", + "config.trainer.devices = 1\n", + "config.trainer.max_epochs = 4\n", + "config.trainer.val_check_interval = 1.0\n", + "\n", + "# for PyTorch Native AMP set precision=16\n", + "config.trainer.precision = 16 if torch.cuda.is_available() else 32\n", + "\n", + "# setup cluster environment parameters\"\n", + "# use torch elastic cluster environment so `create_process_externally` is True\n", + "# the launcher is set to None. It will not try to spawn new processes.\n", + "# It won't create the misconfiguration error because of the `interactive session`\n", + "os.environ[\"LOCAL_RANK\"] = '0'\n", + "os.environ[\"RANK\"] = '0'\n", + "os.environ[\"WORLD_SIZE\"] = '1'\n", + "\n", + "strategy = NLPDDPStrategy(find_unused_parameters=False, no_ddp_communication_hook=True)\n", + "plugins = [TorchElasticEnvironment()]\n", + "trainer = pl.Trainer(plugins= plugins, strategy=strategy, **config.trainer)\n", + "\n", + "print(\"Trainer config - \\n\")\n", + "print(OmegaConf.to_yaml(config.trainer))" + ] + }, + { + "cell_type": "markdown", + "id": "4d0124c1", + "metadata": {}, + "source": [ + "# Setting up a NeMo Experiment\n", + "\n", + "NeMo has an experiment manager that handles logging and checkpointing for us, so let's use it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2c943ba", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.utils.exp_manager import exp_manager\n", + "\n", + "# Set name of the experiment \n", + "config.name = 'p_tuning'\n", + "config.exp_manager.resume_if_exists = False\n", + "\n", + "# Init the experiment manager and view the exp_dir\n", + "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n", + "exp_dir = str(exp_dir)\n", + "print(exp_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "5860bd90", + "metadata": {}, + "source": [ + "We can also set learning hyperparameters as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c4ec542", + "metadata": {}, + "outputs": [], + "source": [ + "# Set some of the learning parameters\n", + "config.model.optim.lr = 1e-4\n", + "config.model.precision = config.trainer.precision" + ] + }, + { + "cell_type": "markdown", + "id": "298b3dce", + "metadata": {}, + "source": [ + "# First P-Tuning Session\n", + "The only thing left to do is load up the model and begin p-tuning!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4bda19b", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import MegatronGPTPromptLearningModel\n", + "\n", + "model = MegatronGPTPromptLearningModel(cfg=config.model, trainer=trainer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d99f433", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Training set to 2 epochs by default in a cell above\n", + "# Each epoch will take around 1min 15sec, but training time can vary\n", + "trainer.fit(model)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "6aab09d4", + "metadata": {}, + "source": [ + "# Inference After P-Tuning\n", + "One way to run inference after p-tuning or prompt-tuning your model is to call `model.generate()`. `model.generate()` takes in \n", + "\n", + "- `inputs` which can be either a list of dictionary objects or `.jsonl` files containing dictionary objects, \n", + "- `length_params`\n", + "- `sampling_params`\n", + "\n", + "as arguments. More information about the [text generation API can be found here](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/transformer/text_generation.py).\n", + "\n", + "If `length_params` and `sampling_params` are set to `None`, the model generates output with a greedy decoding strategy and generates up to `30` new tokens. Most predictive downstream tasks (not text generation tasks), use greedy sampling. To see other ways to run inference with your prompt learning model and more details on how to define various inference parameters, visit `examples/nlp/language_modeling/megatron_gpt_eval.py`.\n", + "\n", + "Below are some randomly selected test examples from the sentiment classification and intent and slot classification test files. Notice that the `label` field is dropped from all test examples. The `MegatronPromptLearningDataset` called within `.generate()` automatically leaves fields in the prompt template empty when they are not provided in the data json. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc95e764", + "metadata": {}, + "outputs": [], + "source": [ + "test_examples = [\n", + " {\"taskname\": \"squad\", \"context\": \"The build was released for download later in the day in standard 32-bit and 64-bit versions, plus a special 64-bit version which included SDKs and developer tools (Visual Studio Express and Expression Blend) for developing Metro-style apps. The Windows Store was announced during the presentation, but was not available in this build. According to Microsoft, there were about 535,000 downloads of the developer preview within the first 12 hours of its release. Originally set to expire on March 11, 2012, in February 2012 the Developer Preview's expiry date was changed to January 15, 2013.\", \"question\": \"When was the Developer preview initially intended to expire?\"},\n", + " {\"taskname\": \"squad\", \"context\": \"The structures of most federal governments incorporate mechanisms to protect the rights of component states. One method, known as 'intrastate federalism', is to directly represent the governments of component states in federal political institutions. Where a federation has a bicameral legislature the upper house is often used to represent the component states while the lower house represents the people of the nation as a whole. A federal upper house may be based on a special scheme of apportionment, as is the case in the senates of the United States and Australia, where each state is represented by an equal number of senators irrespective of the size of its population.\", \"question\": \"What is a bicameral legislature?\"},\n", + " {\"taskname\": \"squad\", \"context\": \"Imported mystery religions, which offered initiates salvation in the afterlife, were a matter of personal choice for an individual, practiced in addition to carrying on one's family rites and participating in public religion. The mysteries, however, involved exclusive oaths and secrecy, conditions that conservative Romans viewed with suspicion as characteristic of \\\"magic\\\", conspiratorial (coniuratio), or subversive activity. Sporadic and sometimes brutal attempts were made to suppress religionists who seemed to threaten traditional morality and unity, as with the senate's efforts to restrict the Bacchanals in 186 BC.\", \"question\": \"What was the practice of religion to the Romans?\"}\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74a5a358", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "response = model.generate(inputs=test_examples, length_params=None)\n", + "\n", + "print('The prediction results of some sample queries with the trained model:')\n", + "for result in response['sentences']:\n", + " print(result)\n", + " print(\"-\" * 30)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 + } \ No newline at end of file