diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
index 8b73e158e898a..ba5d60ebb427c 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py
@@ -554,6 +554,12 @@ def collate_fn(self, batch):
         loss_mask = self._collate_item(loss_mask, max_length=max_length, pad_id=0)
         position_ids = self._collate_item(position_ids, max_length=max_length, pad_id=0)
 
+        # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies.
+        cu_seqlens = torch.IntTensor(cu_seqlens)
+        cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
+        seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
+        max_seqlen, _ = seqlens.max(dim=1, keepdim=True)
+
         processed_batch = {
             'tokens': torch.LongTensor(input_ids),
             'labels': torch.LongTensor(labels),
@@ -562,6 +568,8 @@ def collate_fn(self, batch):
             'position_ids': torch.LongTensor(position_ids),
             'cu_seqlens': torch.IntTensor(cu_seqlens),  # cu_seqlens_q must be in dtype torch.int32
             'token_count': token_count,
+            'cu_seqlens_argmin': cu_seqlens_argmin,
+            'max_seqlen': max_seqlen,
         }
 
         return processed_batch
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 182a3ee0516ed..9ebcd9d647c1b 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -870,12 +870,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
 
             # Transfer needed data to GPU
             required_keys = set()
+            max_seqlen, cu_seqlens_argmin = None, None
             if parallel_state.get_pipeline_model_parallel_world_size() == 1:
                 required_keys.update(batch.keys())
             else:
                 required_keys.add('attention_mask')
                 if 'cu_seqlens' in batch:
                     required_keys.add('cu_seqlens')
+                if 'max_seqlen' in batch:
+                    max_seqlen = batch['max_seqlen'].squeeze()
+                if 'cu_seqlens_argmin' in batch:
+                    cu_seqlens_argmin = batch['cu_seqlens_argmin']
                 if parallel_state.is_pipeline_first_stage():
                     required_keys.update(('tokens', 'position_ids'))
                 if parallel_state.is_pipeline_last_stage():
@@ -906,9 +911,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             if 'cu_seqlens' in batch:  # packed sequence from GPTSFTPackedDataset
                 # these args are passed eventually into TEDotProductAttention.forward()
                 cu_seqlens = batch['cu_seqlens'].squeeze()  # remove batch size dimension (mbs=1)
-                cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]  # remove -1 "paddings" added in collate_fn
+                # remove -1 "paddings" added in collate_fn
+                if cu_seqlens_argmin is not None:
+                    cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
+                else:
+                    cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]
                 forward_args['cu_seqlens_q'] = cu_seqlens
                 forward_args['cu_seqlens_kv'] = cu_seqlens
+                if max_seqlen is not None:
+                    forward_args['max_seqlen_q'] = max_seqlen
+                    forward_args['max_seqlen_kv'] = max_seqlen
                 forward_args['qkv_format'] = 'thd'
 
             output_tensor = model(**forward_args)
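
Reviewer note, not part of the diff: the collate-side bookkeeping can be checked in isolation. Below is a minimal, self-contained sketch with made-up offsets (the values are illustrative, not from NeMo); it shows why taking the max over the padded diffs is safe (the -1 paddings only produce non-positive diffs) and how the precomputed CPU scalars later replace the GPU-side argmin:

```python
import torch

# Packed-sequence offsets for one micro-batch (mbs=1), right-padded with -1
# so every row in a batch has the same width, as in collate_fn above.
# These example values are made up for illustration.
cu_seqlens = torch.IntTensor([[0, 3, 7, 12, -1, -1]])

# Both reductions run on CPU at collate time, so the training step never has
# to synchronize with the device to learn where the padding starts or how
# long the longest packed sequence is.
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)  # tensor([[4]])
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]  # tensor([[3, 4, 5, -13, 0]])
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)  # tensor([[5]])

# Padding diffs are always <= 0 (either -1 - last_offset or -1 - -1), so they
# can never beat a real sequence length in the max above.

# Consumer side: trim the -1 paddings with the precomputed CPU index rather
# than recomputing argmin on the (potentially CUDA) cu_seqlens tensor.
trimmed = cu_seqlens.squeeze()[: cu_seqlens_argmin.item()]  # tensor([0, 3, 7, 12])
print(cu_seqlens_argmin.item(), max_seqlen.item(), trimmed)
```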
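
The model-side hunk can be read the same way. The helper below, `packed_attn_kwargs`, is hypothetical (written for this note; it is not a NeMo or TransformerEngine API); it mirrors how the trimmed offsets and the optional `max_seqlen` end up in the kwargs that are eventually passed to `TEDotProductAttention.forward()`, with the argmin fallback kept for batches produced by a collate_fn that predates this change:

```python
from typing import Dict, Optional

import torch


def packed_attn_kwargs(
    cu_seqlens: torch.Tensor,  # [1, N] int32 offsets, right-padded with -1
    cu_seqlens_argmin: Optional[torch.Tensor] = None,  # [1, 1], CPU
    max_seqlen: Optional[torch.Tensor] = None,  # [1, 1], CPU
) -> Dict:
    """Hypothetical helper mirroring the megatron_gpt_model.py hunk above."""
    cu_seqlens = cu_seqlens.squeeze()  # remove batch size dimension (mbs=1)
    if cu_seqlens_argmin is not None:
        # Fast path: the slice bound is a CPU scalar, so no device sync.
        cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
    else:
        # Fallback: argmin over cu_seqlens itself, a blocking device-to-host
        # copy whenever the tensor lives on the GPU.
        cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]
    kwargs = {'cu_seqlens_q': cu_seqlens, 'cu_seqlens_kv': cu_seqlens, 'qkv_format': 'thd'}
    if max_seqlen is not None:
        max_seqlen = max_seqlen.squeeze()  # scalar tensor, as in the diff
        kwargs['max_seqlen_q'] = max_seqlen
        kwargs['max_seqlen_kv'] = max_seqlen
    return kwargs


# Example with the same made-up batch as the previous sketch.
batch = {
    'cu_seqlens': torch.IntTensor([[0, 3, 7, 12, -1, -1]]),
    'cu_seqlens_argmin': torch.tensor([[4]]),
    'max_seqlen': torch.tensor([[5]]),
}
print(packed_attn_kwargs(batch['cu_seqlens'], batch['cu_seqlens_argmin'], batch['max_seqlen']))
```

Passing `max_seqlen` in from the batch spares the attention layer from re-deriving it on device; when either precomputed key is absent, the sketch (like the diff) degrades to the old behavior rather than failing.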