Skip to content

Commit

Permalink
pre-generate cu_seqlens argmin and max_seqlen to remove host-to-device data copies (#8108)
Browse files Browse the repository at this point in the history

Signed-off-by: Sangkug Lym <slym@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: Eric Harper <complex451@gmail.com>
  • Loading branch information
erhoo82 and ericharper authored Jan 18, 2024
1 parent 92b098a commit 6c40209
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,12 @@ def collate_fn(self, batch):
loss_mask = self._collate_item(loss_mask, max_length=max_length, pad_id=0)
position_ids = self._collate_item(position_ids, max_length=max_length, pad_id=0)

# Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensors to avoid device-to-host copies.
cu_seqlens = torch.IntTensor(cu_seqlens)
cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True)
seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
max_seqlen, _ = seqlens.max(dim=1, keepdim=True)

processed_batch = {
'tokens': torch.LongTensor(input_ids),
'labels': torch.LongTensor(labels),
Expand All @@ -562,6 +568,8 @@ def collate_fn(self, batch):
'position_ids': torch.LongTensor(position_ids),
'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32
'token_count': token_count,
'cu_seqlens_argmin': cu_seqlens_argmin,
'max_seqlen': max_seqlen,
}

return processed_batch
Original file line number Diff line number Diff line change
Expand Up @@ -870,12 +870,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_

# Transfer needed data to GPU
required_keys = set()
max_seqlen, cu_seqlens_argmin = None, None
if parallel_state.get_pipeline_model_parallel_world_size() == 1:
required_keys.update(batch.keys())
else:
required_keys.add('attention_mask')
if 'cu_seqlens' in batch:
required_keys.add('cu_seqlens')
if 'max_seqlen' in batch:
max_seqlen = batch['max_seqlen'].squeeze()
if 'cu_seqlens_argmin' in batch:
cu_seqlens_argmin = batch['cu_seqlens_argmin']
if parallel_state.is_pipeline_first_stage():
required_keys.update(('tokens', 'position_ids'))
if parallel_state.is_pipeline_last_stage():
Expand Down Expand Up @@ -906,9 +911,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
if 'cu_seqlens' in batch: # packed sequence from GPTSFTPackedDataset
# these args are passed eventually into TEDotProductAttention.forward()
cu_seqlens = batch['cu_seqlens'].squeeze() # remove batch size dimension (mbs=1)
cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)] # remove -1 "paddings" added in collate_fn
# remove -1 "paddings" added in collate_fn
if cu_seqlens_argmin is not None:
cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
else:
cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]
forward_args['cu_seqlens_q'] = cu_seqlens
forward_args['cu_seqlens_kv'] = cu_seqlens
if max_seqlen is not None:
forward_args['max_seqlen_q'] = max_seqlen
forward_args['max_seqlen_kv'] = max_seqlen
forward_args['qkv_format'] = 'thd'

output_tensor = model(**forward_args)
Expand Down

0 comments on commit 6c40209

Please sign in to comment.