Ensure fine-tuning/prompt learning work for T5 (#6385)
Signed-off-by: SeanNaren <snarenthiran@nvidia.com>
SeanNaren authored Apr 7, 2023
1 parent df1d5d1 commit a3db3aa
Showing 5 changed files with 94 additions and 53 deletions.
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Dict, List

import torch
from omegaconf import DictConfig, ListConfig
@@ -22,17 +23,23 @@
from nemo.collections.common.metrics.classification_accuracy import ExactStringPerCategoryMatchMetric
from nemo.collections.nlp.data.common.sequence_to_sequence_dataset import SequenceToSequenceDataset
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model, T5Sentinel
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
from nemo.utils import AppState, logging

try:
from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator, get_num_microbatches
from apex.transformer.pipeline_parallel.utils import (
_reconfigure_microbatch_calculator,
get_micro_batch_size,
get_num_microbatches,
)

HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
HAVE_APEX = False

try:
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

HAVE_MEGATRON_CORE = True

@@ -138,11 +145,6 @@ def setup(self, stage=None):
if hasattr(self, '_train_ds'):
self.setup_training_data()

def _process_global_batch(self, global_batch):
"""Optionally processes a global batch."""
# TODO: maybe remove this now that we've refactored data batch sizes.
return global_batch

def on_validation_epoch_start(self):
app_state = AppState()
_reconfigure_microbatch_calculator(
@@ -268,36 +270,72 @@ def _reconfigure_and_process_inference_batch(self, batch, ds_config):
data_parallel_size=parallel_state.get_data_parallel_world_size(),
)

processed_batch = self._process_global_batch(batch)
return processed_batch
def fwd_bwd_step(self, batch, batch_idx, forward_only):
"""
Dataloader produces a global batch which is turned into a list of microbatches.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
"""
# Get seq length of batch
_, seq_length = batch[0].shape
_, dec_seq_length = batch[1].shape
tensor_shape = [seq_length, get_micro_batch_size(), self.hidden_size]
data_iter = get_iterator_k_split(batch, get_num_microbatches())

fwd_bwd_function = get_forward_backward_func()

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=data_iter,
model=[self.enc_dec_model],
num_microbatches=get_num_microbatches(),
forward_only=forward_only,
tensor_shape=tensor_shape,
decoder_seq_length=dec_seq_length,
dtype=self.autocast_dtype,
grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
sequence_parallel=self.cfg.get('sequence_parallel', False),
)

# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
loss_mean = loss_tensor.mean()
else:
# we're not on the last pipeline stage so no losses
loss_mean = torch.tensor(0.0).cuda()

return loss_mean

def inference_step(self, batch, batch_idx, mode, dataloader_idx=0):
# Regular finetuning datasets will return a list of dicts for each microbatch. But T0 datasets will return a single dict for the global batch.
def inference_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, mode: str, dataloader_idx=0):
# Regular finetuning datasets will return a list of dicts for each microbatch.
# But T0 datasets will return a single dict for the global batch.
batch_has_lang_information = isinstance(batch, list) and len(batch[0]) == 7
data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds

processed_batch = self._reconfigure_and_process_inference_batch(batch, data_cfg)
self._reconfigure_and_process_inference_batch(batch, data_cfg)

# Call parent validation step to get the loss.
# NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI, this will be ignored in the parent class.
loss = super().validation_step(processed_batch, batch_idx)
# NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI,
# this will be ignored.
loss = self.fwd_bwd_step(self._process_batch(batch), batch_idx, forward_only=True)

predicted_token_ids, _ = self.decode(
tokens_enc=processed_batch['text_enc'],
enc_mask=processed_batch['enc_mask'],
tokens_enc=batch['text_enc'],
enc_mask=batch['enc_mask'],
num_tokens_to_generate=30,
bos_id=self.tokenizer.pad_id if data_cfg.replace_bos_with_pad else self.tokenizer.bos_id,
)

# Special ids to text function to handle stripping <eos> and special tokens with sentencepiece tokenizers.
preds_text = MegatronT5FinetuneModel.ids_to_text(predicted_token_ids, self.tokenizer)
labels_text = MegatronT5FinetuneModel.ids_to_text(processed_batch['labels'], self.tokenizer)
input_text = MegatronT5FinetuneModel.ids_to_text(processed_batch['text_enc'], self.tokenizer)
labels_text = MegatronT5FinetuneModel.ids_to_text(batch['labels'], self.tokenizer)
input_text = MegatronT5FinetuneModel.ids_to_text(batch['text_enc'], self.tokenizer)

if not batch_has_lang_information:
categories = [None] * len(preds_text)
else:
categories = processed_batch['lang']
categories = batch['lang']

metric = self.val_metric[dataloader_idx] if mode == 'validation' else self.test_metric[dataloader_idx]
assert len(categories) == len(preds_text) == len(labels_text)
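For reference, the new fwd_bwd_step above splits the global batch into micro-batches with get_iterator_k_split, runs them through the Megatron Core pipeline schedule, and averages the per-micro-batch losses on the last pipeline stage. The loss-reduction tail of that method can be exercised in isolation; a minimal sketch with the schedule's output mocked and everything kept on CPU (illustrative values, not from the commit):

import torch

def average_microbatch_losses(losses_reduced_per_micro_batch):
    # Mirrors the tail of fwd_bwd_step: each entry is a dict holding an 'avg'
    # loss tensor; only the last pipeline stage sees a non-empty list.
    if losses_reduced_per_micro_batch:
        loss_tensors_list = [d['avg'] for d in losses_reduced_per_micro_batch]
        return torch.concat(loss_tensors_list).mean()
    # Intermediate pipeline stages report a zero placeholder loss.
    return torch.tensor(0.0)

# Mocked schedule output for two micro-batches.
fake_losses = [{'avg': torch.tensor([0.9])}, {'avg': torch.tensor([1.1])}]
assert torch.isclose(average_microbatch_losses(fake_losses), torch.tensor(1.0))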
@@ -14,39 +14,32 @@

import itertools
import os
import re
from functools import partial
from typing import Any, List, Optional, Union

import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer
from torch import Tensor

from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset
from nemo.collections.nlp.metrics.prompt_learning_metrics import AccuracyScore, BLEUScore, ROUGEScores
from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel
from nemo.collections.nlp.models.language_modeling.megatron_base_prompt_learning_model import (
MegatronBasePromptLearningModel,
)
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common import (
PromptEncoder,
PromptEncoderType,
VirtualPromptPlaceholderToken,
VirtualPromptSource,
VirtualPromptStyle,
from nemo.collections.nlp.modules.common import VirtualPromptPlaceholderToken, VirtualPromptSource, VirtualPromptStyle
from nemo.collections.nlp.modules.common.megatron.utils import (
average_losses_across_data_parallel_group,
get_iterator_k_split,
)
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.collections.nlp.modules.common.text_generation_utils import (
get_default_length_params,
get_default_sampling_params,
megatron_gpt_generate,
)
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam, TextGeneration
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, NLPSaveRestoreConnector
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.utils import AppState, logging
@@ -301,7 +294,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
# Get seq length of batch
_, seq_length = batch[0].shape
tensor_shape = [seq_length, get_micro_batch_size(), self.hidden_size]
data_iter = self.get_iterator_k_split(batch, get_num_microbatches())
data_iter = get_iterator_k_split(batch, get_num_microbatches())

fwd_bwd_function = get_forward_backward_func()

@@ -522,10 +522,26 @@ def allreduce_word_and_position_embeddings(self):
grad, group=parallel_state.get_decoder_relative_position_embedding_group()
)

def _process_batch(self, global_batch: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
# If the decoder input starts with <pad> instead of <bos>, which is the case for huggingface T5 models, we don't want to mask the first token.
# For NeMo-Megatron, the sequence starts with <bos>, which is never masked so we can always set index 0 to be unmasked.
global_batch['dec_mask'][:, 0] = 1

return [
global_batch["text_enc"],
global_batch["text_dec"],
global_batch["loss_mask"],
global_batch["labels"],
global_batch["enc_mask"],
global_batch["dec_mask"],
]

def get_forward_output_and_loss_func(self):
def fwd_output_and_loss_func(dataloader_iter, model):
batch = next(dataloader_iter)
batch = self._process_training_batch(batch)
if isinstance(batch, dict):
# convert to list if not already converted.
batch = self._process_batch(batch)
batch = [x.cuda(non_blocking=True) for x in batch]
encoder_input_ids, decoder_input_ids, loss_mask, lm_labels, encoder_attn_mask, decoder_attn_mask = batch

@@ -769,21 +785,6 @@ def _process_global_batch_without_megatron_batch_sampler(self, global_batch, tok
'dec_mask': dec_mask_tensor,
}

def _process_training_batch(self, batch: Dict[str, torch.Tensor]) -> List[torch.Tensor]:
# If the decoder input starts with <pad> instead of <bos>, which is the case for huggingface T5 models, we don't want to mask the first token.
# For NeMo-Megatron, the sequence starts with <bos>, which is never masked so we can always set index 0 to be unmasked.
batch['dec_mask'][:, 0] = 1
# Megatron Core expects a data iterator to iterate over the micro-batches.
# split the global batch into smaller micro-batches.
return [
batch["text_enc"],
batch["text_dec"],
batch["loss_mask"],
batch["labels"],
batch["enc_mask"],
batch["dec_mask"],
]

def build_train_valid_test_datasets(self):
raise NotImplementedError("Please implement this method in child-class")

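To see what the new _process_batch produces, here is a toy sketch (shapes and values are made up for illustration, not taken from the commit). It applies the same two steps: force the first decoder-mask position to 1, then flatten the dict into the positional list that fwd_output_and_loss_func unpacks:

import torch

# Toy global batch with the keys _process_batch expects
# (batch size 2, encoder length 4, decoder length 3 — illustrative only).
global_batch = {
    'text_enc': torch.randint(0, 100, (2, 4)),
    'text_dec': torch.randint(0, 100, (2, 3)),
    'loss_mask': torch.ones(2, 3),
    'labels': torch.randint(0, 100, (2, 3)),
    'enc_mask': torch.ones(2, 4),
    'dec_mask': torch.zeros(2, 3),
}

# Never mask the first decoder token: it is <bos> for NeMo-Megatron,
# or <pad> for converted HuggingFace T5 checkpoints.
global_batch['dec_mask'][:, 0] = 1

# Flatten into the positional order the forward/loss function expects.
batch_as_list = [
    global_batch['text_enc'],
    global_batch['text_dec'],
    global_batch['loss_mask'],
    global_batch['labels'],
    global_batch['enc_mask'],
    global_batch['dec_mask'],
]
assert global_batch['dec_mask'][0, 0] == 1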
@@ -27,8 +27,10 @@
)
from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
from nemo.collections.nlp.modules.common import VirtualPromptSource
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.collections.nlp.modules.common.megatron.utils import (
average_losses_across_data_parallel_group,
get_iterator_k_split,
)
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.utils import AppState, logging
@@ -179,7 +181,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
_, seq_length = batch[0].shape
_, dec_seq_length = batch[1].shape
tensor_shape = [seq_length, get_micro_batch_size(), self.hidden_size]
data_iter = self.get_iterator_k_split(batch, get_num_microbatches())
data_iter = get_iterator_k_split(batch, get_num_microbatches())

fwd_bwd_function = get_forward_backward_func()

11 changes: 9 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/utils.py
@@ -13,9 +13,9 @@
# limitations under the License.

"""Utilities for models."""

import itertools
import math
from typing import Dict, List, Tuple, Union
from typing import Dict, Iterator, List, Tuple, Union

import torch

@@ -366,3 +366,10 @@ def get_all_params_for_weight_decay_optimization(
]

return ({'params': weight_decay_params},)


def get_iterator_k_split(batch: List[torch.Tensor], microbatches: int) -> Iterator:
assert batch[0].shape[0] % microbatches == 0, "Issue with batch size configuration!"
split_batch = [torch.tensor_split(item, microbatches, dim=0) for item in batch]
microbatches = [[elem[i] for elem in split_batch] for i in range(microbatches)]
return itertools.chain(microbatches)
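A quick usage sketch for the new helper (toy tensors, assuming the function is importable from the utils module the diff above touches): a global batch of 4 samples split into 2 micro-batches yields an iterator over two tensor lists, each with batch dimension 2.

import torch
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split

enc = torch.arange(8).reshape(4, 2)   # global batch: 4 samples, seq length 2
mask = torch.ones(4, 2)

micro_batches = list(get_iterator_k_split([enc, mask], 2))

assert len(micro_batches) == 2            # two micro-batches
assert micro_batches[0][0].shape[0] == 2  # each holds 2 samples
assert torch.equal(micro_batches[1][0], enc[2:])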
