Fixed Context Parallel HtoD sync (#8557)
* Fixed cp HtoD sync

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: ataghibakhsh <ataghibakhsh@nvidia.com>
3 people authored and JRD971000 committed Mar 15, 2024
1 parent 220aaa9 commit 206fe74
Showing 2 changed files with 6 additions and 2 deletions.
@@ -960,7 +960,9 @@ def get_batch_on_this_context_parallel_rank(self, batch):
                     val.shape[seq_dim] // (2 * cp_size),
                     *val.shape[(seq_dim + 1) :],
                 )
-                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
+                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
+                    non_blocking=True
+                )
                 val = val.index_select(seq_dim, index)
                 val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
                 batch[key] = val
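The fix is the same in both files (the second hunk follows below): the two-element index tensor that selects this CP rank's pair of sequence chunks was previously created with torch.tensor(..., device=<cuda device>). Building a tensor from Python scalars directly on a CUDA device materializes it in pageable host memory first, and the ensuing host-to-device copy blocks the CPU until it finishes. Creating it in pinned (page-locked) host memory and copying with non_blocking=True instead enqueues the copy on the current CUDA stream without stalling the host. A minimal standalone sketch of the pattern, with made-up values (cp_rank and cp_size normally come from parallel_state):

import torch

cp_rank, cp_size = 0, 4  # illustrative values only

# What the commit removes: an implicit, host-blocking HtoD copy.
# index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device="cuda")

# What the commit adds: stage in pinned host memory, then enqueue an
# asynchronous copy on the current CUDA stream.
index = torch.tensor(
    [cp_rank, 2 * cp_size - cp_rank - 1], device="cpu", pin_memory=True
).cuda(non_blocking=True)

# Kernels issued later on the same stream are ordered after the copy,
# so the index can be consumed immediately and still be correct.
val = torch.randn(8, 2 * cp_size, 16, device="cuda")
val = val.index_select(1, index)  # keeps chunks cp_rank and its mirror
print(val.shape)  # torch.Size([8, 2, 16])

Correctness here relies on stream ordering (index_select is queued behind the copy on the same stream) and on PyTorch's caching pinned-memory allocator, which tracks the copy's stream before reusing the temporary host buffer.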
@@ -726,7 +726,9 @@ def set_input_tensor(self, input_tensor):
     def get_position_embedding_on_this_context_parallel_rank(self, position_embedding, seq_dim):
         cp_size = parallel_state.get_context_parallel_world_size()
         cp_rank = parallel_state.get_context_parallel_rank()
-        cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=position_embedding.device)
+        cp_idx = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
         position_embedding = position_embedding.view(
             *position_embedding.shape[:seq_dim], 2 * cp_size, -1, *position_embedding.shape[(seq_dim + 1) :]
         )
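One way to observe the difference (this micro-benchmark is not part of the commit; names and numbers are illustrative) is to time only host-side wall clock, without synchronizing inside the loop, so that any implicit HtoD sync shows up as host blocking:

import time
import torch

assert torch.cuda.is_available()

def bench(make_index, iters=1000):
    torch.cuda.synchronize()  # start from a drained stream
    t0 = time.perf_counter()
    for _ in range(iters):
        make_index()
    host = time.perf_counter() - t0  # host time only, no sync here
    torch.cuda.synchronize()  # drain before the next run
    return host

# Old pattern: each call blocks the host on a tiny HtoD copy.
blocking = bench(lambda: torch.tensor([0, 7], device="cuda"))

# New pattern: pinned staging + async copy; PyTorch caches pinned host
# blocks, so the per-call allocation cost is low after warm-up.
overlapped = bench(
    lambda: torch.tensor([0, 7], device="cpu", pin_memory=True).cuda(non_blocking=True)
)

print(f"host-blocked: {blocking:.4f}s   pinned non-blocking: {overlapped:.4f}s")

On most machines the second number is noticeably smaller; the practical win in training is that the host thread keeps launching kernels instead of stalling once per batch on each small index tensor.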
