
sequence parallel with communication overlap #5691

Merged: 20 commits merged into microsoft:master on Aug 1, 2024

Conversation

@inkcherry (Contributor) commented Jun 21, 2024

SP is a fantastic piece of work: very elegant and concise. At the current stage, a transformer layer's forward and backward passes involve 8 all-to-all operations, with 5 opportunities to overlap this communication with computation (see the sketch after this list):

  • Forward pass: the Q/K/V projections can be pipelined with some of the all-to-all communications.
  • Backward pass: the DQ, DK, and DV all-to-alls can be pipelined with the matrix operations.
  • Backward pass: the weight-gradient path (DO_w) can run in parallel with the input-gradient path (DO_input), which involves both matrix operations and all-to-all communication. Similar overlap-comm strategies are used in Megatron for TP/TP-sp parallelism.
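
To make the forward-pass case concrete, here is a minimal PyTorch sketch of the overlap idea (hypothetical function and variable names, not DeepSpeed's actual code; the reshapes Ulysses needs to swap the sequence and head dimensions are omitted): each projection's all-to-all is launched with async_op=True, so the next projection's matmul runs while the previous all-to-all is still in flight.

```python
import torch
import torch.distributed as dist

def qkv_projection_with_overlap(hidden, wq, wk, wv, group):
    # Project Q and launch its all-to-all without blocking.
    q = torch.matmul(hidden, wq)
    q_out = torch.empty_like(q)
    q_work = dist.all_to_all_single(q_out, q, group=group, async_op=True)

    # The K projection overlaps with Q's all-to-all, still in flight.
    k = torch.matmul(hidden, wk)
    k_out = torch.empty_like(k)
    k_work = dist.all_to_all_single(k_out, k, group=group, async_op=True)

    # The V projection overlaps with the Q and K all-to-alls.
    v = torch.matmul(hidden, wv)
    v_out = torch.empty_like(v)
    v_work = dist.all_to_all_single(v_out, v, group=group, async_op=True)

    # Block only when attention actually needs the redistributed tensors.
    q_work.wait()
    k_work.wait()
    v_work.wait()
    return q_out, k_out, v_out
```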
I tested with 1 node × 8 GPUs (1N8C), ZeRO stage 1, activation checkpointing disabled, ds-sp=8, and gbs=16, on two configurations:
  • 1B model, 64K sequence length
  • 7B model, 16K sequence length
Both showed over 10% improvement, even though some TFLOPs numbers were already at a relatively good level. (I also found that for Megatron-DeepSpeed, using split QKV by itself can improve performance by reducing slice + cat operations in the forward/backward passes; a small illustration follows below.)
co-work with microsoft/Megatron-DeepSpeed#415
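
As a small illustration of the split-QKV point above (illustrative shapes only, not the Megatron-DeepSpeed code): a fused QKV projection needs a chunk/slice in the forward pass and a matching cat of the gradients in the backward pass, whereas separate Q/K/V projections avoid both and let each tensor start its all-to-all as soon as its own matmul finishes.

```python
import torch

hidden = torch.randn(4096, 2048, requires_grad=True)   # [tokens, hidden], illustrative sizes

# Fused QKV: one matmul, but the forward pass slices the result into Q/K/V
# and the backward pass concatenates their gradients back together.
w_fused = torch.randn(2048, 3 * 2048, requires_grad=True)
qkv = hidden @ w_fused
q, k, v = qkv.chunk(3, dim=-1)

# Split QKV: three matmuls, no slice in forward and no cat in backward, and
# each tensor is ready for its all-to-all as soon as its own matmul finishes.
wq = torch.randn(2048, 2048, requires_grad=True)
wk = torch.randn(2048, 2048, requires_grad=True)
wv = torch.randn(2048, 2048, requires_grad=True)
q2, k2, v2 = hidden @ wq, hidden @ wk, hidden @ wv
```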

@tjruwase requested review from samadejacobs and tohtana and removed the request for mrwyattii on June 21, 2024 22:17
@Edenzzzz commented:

Does overlapping only happen when the computation doesn't depend on the communication?

@inkcherry (Contributor, Author) commented Jul 5, 2024

> Does overlapping only happen when the computation doesn't depend on the communication?

@Edenzzzz Yes, some dependencies require manual synchronization.

@inkcherry (Contributor, Author) commented Jul 10, 2024

We set gbs=2, sp=4, seq_len=16K, model size=1B, zero_stage=1, activation checkpointing disabled, and use-flash-attn-v2, and compared three runs:

  • without this patch
  • with this patch, split QKV + sp-overlap-comm enabled
  • with this patch, split QKV + sp-overlap-comm disabled

We list the loss curve and grad-norm curve below; they are consistent across the three runs.
[figures: loss curve and grad norm curve comparisons]

@samadejacobs (Contributor) commented:

@inkcherry, many thanks for this excellent contribution to the DeepSpeed codebase. To help with our review, could you please add (1) unit test(s) and (2) numbers on the parallel performance improvements (throughput and latency) to the pull request? Your continuous and remarkable contributions to DeepSpeed are appreciated.

@loadams loadams merged commit 17ed7c7 into microsoft:master Aug 1, 2024
11 checks passed
@Edenzzzz commented:

@inkcherry Thanks for your insight! Can I ask why we need sp_stream here? It seems never to be used, e.g. via torch.cuda.stream(sp_stream).

@inkcherry (Contributor, Author) commented Aug 30, 2024

> @inkcherry Thanks for your insight! Can I ask why we need sp_stream here? It seems never to be used, e.g. via torch.cuda.stream(sp_stream).

Hi @Edenzzzz,
apologies for missing your comment. DeepSpeed's sequence parallelism is designed in a modular way, which means we can't freely insert communication calls wherever we want in order to use async_op. When a PyTorch computation kernel is launched before the communication kernel, the communication op automatically synchronizes with the default stream, so we need a custom stream (and even an event) to achieve parallelism between computation and communication, while still maintaining the dependencies between them correctly. The stream setup for this lives in the Megatron-DeepSpeed implementation.
Two of the three cases mentioned in this PR fall into this category and use an additional stream; the other case uses async_op=True with all-to-all.
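
A minimal sketch of that stream-plus-event pattern, assuming plain PyTorch with an NCCL process group (hypothetical helper name, not the actual Megatron-DeepSpeed code): the all-to-all is enqueued on a dedicated side stream that waits on an event recorded by the producing compute kernel, so the default stream stays free for independent work; the consumer makes the default stream wait on a second event only right before it uses the result.

```python
import torch
import torch.distributed as dist

sp_stream = torch.cuda.Stream()   # dedicated communication stream

def all_to_all_on_side_stream(inp, group):
    out = torch.empty_like(inp)
    ready = torch.cuda.Event()    # producer kernel has been enqueued
    done = torch.cuda.Event()     # all-to-all has finished on the device

    # Mark the point on the default stream after which `inp` is valid, and
    # tell the caching allocator both tensors are also used on sp_stream.
    ready.record(torch.cuda.current_stream())
    inp.record_stream(sp_stream)
    out.record_stream(sp_stream)

    with torch.cuda.stream(sp_stream):
        sp_stream.wait_event(ready)                      # compute -> comm dependency
        dist.all_to_all_single(out, inp, group=group)    # ordered after `ready`
        done.record(sp_stream)

    return out, done

# Usage: keep running independent work on the default stream, and make it
# wait on `done` only right before `out` is consumed:
#   out, done = all_to_all_on_side_stream(x, sp_group)
#   ...independent matmuls here...
#   torch.cuda.current_stream().wait_event(done)
```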
