fix sequence parallel(Ulysses) grad scale for zero0#5555
Merged
samadejacobs merged 2 commits intodeepspeedai:masterfrom Jun 5, 2024
Merged
fix sequence parallel(Ulysses) grad scale for zero0#5555samadejacobs merged 2 commits intodeepspeedai:masterfrom
samadejacobs merged 2 commits intodeepspeedai:masterfrom
Conversation
deepspeed/runtime/engine.py
Outdated
| # to maintain the gradients value unaffected by ep_size setting, | ||
| # utilize dp_world_size for allreduce average | ||
| dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) | ||
| dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) / float(self.sequence_parallel_size) |
Contributor
There was a problem hiding this comment.
@inkcherry, can you help me understand why scale by sp_size? get_data_parallel_group != get_sequence_data_parallel_group, you should have correct value already, no?
Contributor
Author
There was a problem hiding this comment.
Thanks for the review! @samadejacobs. Yes, this should be the correct value. We should only need to modify the dp_world_size in the above instance.
Contributor
Author
|
Hi, @samadejacobs I have removed the modifications you mentioned in that line. Could you please help review the other parts again? Thanks! |
tohtana
approved these changes
Jun 5, 2024
sfc-gh-reyazda
pushed a commit
to Snowflake-Labs/DeepSpeed
that referenced
this pull request
Jun 10, 2024
use dp_world_size for grad reduction, instead of seq_dp_world_size. Currently, for zero0, only sparse tensors use the correct world_size. tiny model with sp=4 grad norm test: grad_norm | step1 | step2 | step3 | step4 |step5 | step100 -- | -- | -- | -- | -- | --| -- zero1 | 15.825 | 16.646|15.853 | 16.159 | 17.333 | 15.555 zero0 | 3.956 | 4.161 | 3.963 | 4.040 | 4.333| 3.889 zero0(this patch) | 15.825 | 16.646 | 15.853| 16.159 | 17.333 | 15.554
23 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
use dp_world_size for grad reduction, instead of seq_dp_world_size.
Currently, for zero0, only sparse tensors use the correct world_size.
tiny model with sp=4 grad norm test: