From 789d7e474ae9b09b77b4cf3c7a5a74061b22419b Mon Sep 17 00:00:00 2001
From: Valerie Sarge
Date: Thu, 25 Jan 2024 10:44:02 -0800
Subject: [PATCH] Correct padding for SFT input data to account for sequence parallel + TE's fp8 op dimension requirements (#8240)

* Alter GPTSFTDataset / GPTSFTPackedDataset to account for SP when padding sequences to ensure divisibility by 8/16 for TE with fp8

Signed-off-by: Valerie Sarge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Valerie Sarge
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Pablo Garay
---
 .../nlp/data/language_modeling/megatron/gpt_sft_dataset.py | 6 ++++--
 .../nlp/models/language_modeling/megatron_gpt_sft_model.py | 7 +++++++
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index 10608dbf656b..ed84eca6e7a8 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -39,6 +39,7 @@ def __init__(
         tokenizer: TokenizerSpec,
         max_seq_length: int = 1024,
         min_seq_length: int = 1,
+        pad_seq_length_to_mult: int = 16,
         add_bos: bool = False,
         add_eos: bool = True,
         add_sep: bool = False,
@@ -88,6 +89,7 @@ def __init__(
         self.file_path = file_path
         self.max_seq_length = max_seq_length
         self.min_seq_length = min_seq_length
+        self.pad_seq_length_to_mult = pad_seq_length_to_mult
         self.add_bos = add_bos
         self.add_eos = add_eos
         self.add_sep = add_sep
@@ -440,7 +442,7 @@ def collate_fn(self, batch):
         if self.pad_to_max_length:
             max_length = self.max_seq_length
         else:
-            max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))
+            max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))
         assert max_length <= self.max_seq_length
 
         attention_mask = [self._create_attention_mask(max_length) for _ in batch]
@@ -534,7 +536,7 @@ def collate_fn(self, batch):
         # for many datasets in practice, all packed sequence lengths are very close to the
         # target length (2048, 4096, 8192), so there is very minimal padding
         max_length = max(len(l) for l in input_ids)
-        max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))
+        max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))
         assert max_length <= self.max_seq_length
 
         position_ids: List[List[int]] = []
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 14d043d10d8c..1f92c88cb774 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -252,6 +252,12 @@ def _build_dataset(self, data_cfg, is_train=True):
             )
             data_cfg.max_seq_length = self.cfg.max_position_embeddings
 
+        # TE requires that the first input dim is divisible by 8 and the second by 16 for fp8
+        # When using sequence parallel, sequence will further be split by TP size
+        pad_seq_length_to_mult = (
+            8 * self.cfg.get('tensor_model_parallel_size', 1) if self.cfg.get('sequence_parallel', False) else 16
+        )
+
         for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset):
             if self.cfg.data.get("chat", False):
                 dataset_cls = GPTSFTChatDataset
@@ -265,6 +271,7 @@ def _build_dataset(self, data_cfg, is_train=True):
                 tokenizer=self.tokenizer,
                 max_seq_length=data_cfg.max_seq_length,
                 min_seq_length=data_cfg.min_seq_length,
+                pad_seq_length_to_mult=pad_seq_length_to_mult,
                 add_bos=data_cfg.get('add_bos', False),
                 add_eos=data_cfg.get('add_eos', True),
                 add_sep=data_cfg.get('add_sep', False),
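For reference, a minimal standalone sketch of the padding rule this patch implements (illustration only, not part of the patch; the helper names ceil_to_nearest and pad_multiple, the TP size, and the batch/max lengths below are made up for the example):

# Illustration only: how the padding multiple is chosen and applied.
# TE's fp8 ops need sequence dims divisible by 8/16; with sequence parallel,
# each TP rank sees only seq_len / TP tokens, so pad to 8 * TP instead of 16.


def ceil_to_nearest(n: int, m: int) -> int:
    # Round n up to the nearest multiple of m (the role _ceil_to_nearest plays in collate_fn).
    return (n + m - 1) // m * m


def pad_multiple(sequence_parallel: bool, tensor_model_parallel_size: int) -> int:
    # Same rule as the pad_seq_length_to_mult computation in _build_dataset.
    return 8 * tensor_model_parallel_size if sequence_parallel else 16


# Hypothetical example: TP=4 with sequence parallel -> pad to a multiple of 32,
# so each rank's shard (padded_len / 4) is still divisible by 8.
max_seq_length = 2048          # assumed config value
batch_max_len = 1000           # longest sample in an assumed batch
mult = pad_multiple(sequence_parallel=True, tensor_model_parallel_size=4)
padded_len = min(max_seq_length, ceil_to_nearest(batch_max_len, mult))
assert mult == 32 and padded_len == 1024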