[zero-1] ignore overlap/contiguous_gradients flags #1246

Merged
jeffra merged 3 commits into master from jeffra/z1-defaults on Jul 27, 2021
Conversation

@jeffra (Collaborator) commented on Jul 23, 2021

  1. Overlap and contiguous gradients are meaningless in stage 1, so these flags should be ignored (see the sketch after this list).
  2. Found a typo'd MPI variable that would cause a crash (I guess it's not a frequently used code path).
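
For illustration, here is a minimal sketch of the override in item 1, assuming a plain dict-style ZeRO config; the helper name and warning message are hypothetical, not DeepSpeed's actual internals.

```python
import logging

logger = logging.getLogger(__name__)

ZERO_STAGE_1 = 1  # stage 1 only partitions optimizer states


def ignore_stage1_flags(zero_config: dict) -> dict:
    """Hypothetical helper: force-disable flags that have no effect in ZeRO stage 1."""
    if zero_config.get("stage") == ZERO_STAGE_1:
        for flag in ("overlap_comm", "contiguous_gradients"):
            if zero_config.get(flag):
                logger.warning(
                    "ZeRO stage 1 registers no backward hooks, so %r is ignored.", flag
                )
            zero_config[flag] = False
    return zero_config


# Both flags are silently turned off for a stage-1 config.
print(ignore_stage1_flags({"stage": 1, "overlap_comm": True, "contiguous_gradients": True}))
```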

@samyam (Contributor) left a comment

LGTM and makes sense to do this. Since none of the backward hooks are triggered by stage 1, there are no overlapping or contiguous gradients.
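
For illustration, a minimal PyTorch sketch of the backward-hook pattern that stage 2 relies on (and that the separate stage-1 implementation never registered); the function below is a simplified assumption, not DeepSpeed's actual code.

```python
import torch


def register_reduction_hooks(params, reduce_fn):
    """Attach a post-accumulation hook to each parameter (the stage-2 style pattern)."""
    handles = []
    for p in params:
        if not p.requires_grad:
            continue
        # The grad_fn of a trivial view exposes the AccumulateGrad node for p.
        grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]

        def make_hook(param):
            def hook(*_):
                reduce_fn(param)  # e.g. launch an (async) all-reduce of param.grad
            return hook

        # Keep a reference to grad_acc so the node is not garbage-collected.
        handles.append((grad_acc, grad_acc.register_hook(make_hook(p))))
    return handles


# Toy usage: the hook fires for each parameter while backward is running.
layer = torch.nn.Linear(4, 4)
register_reduction_hooks(layer.parameters(), lambda p: print("reduce", tuple(p.shape)))
layer(torch.randn(2, 4)).sum().backward()
```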

jeffra merged commit 6ae756c into master on Jul 27, 2021
jeffra deleted the jeffra/z1-defaults branch on July 27, 2021 at 00:13
jeffra added a commit that referenced this pull request Jul 29, 2021
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
@xiaopqr commented on Jan 6, 2023

> LGTM and makes sense to do this. Since none of the backward hooks are triggered by stage 1, there are no overlapping or contiguous gradients.

What should we do if we want to overlap communication and computation in ZeRO 1?

github-merge-queue bot pushed a commit that referenced this pull request Jan 5, 2024
…4887)

The `overlap_comm` and `contiguous_gradients` options have been ignored
in ZeRO stage 1 since #1246.
At that time, ZeRO 1 and 2 were implemented separately (see
https://github.com/microsoft/DeepSpeed/tree/6ae756c03f12674f17aef90622e7664a8af9d2af/deepspeed/runtime/zero).
ZeRO 1 did not register gradient hooks to overlap backward with the
gradient all-reduce, so it was fine to ignore `overlap_comm` and
`contiguous_gradients`. However, in the current implementation, ZeRO 1
and 2 share almost the same code (`stage_1_and_2.py`), so features like
`overlap_comm` and `contiguous_gradients` can also be enabled for
ZeRO 1 (please correct me if I'm mistaken).

With this PR, turning on `overlap_comm` and `contiguous_gradients` for
ZeRO 1 on the [SFT
task](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning)
produces exactly the same training curve as the latest master.


![image](https://github.com/microsoft/DeepSpeed/assets/39846316/bda3be7b-c236-4e08-b687-b3cd01f5cc73)

I also see a ~1.05x end-to-end speedup from overlapping backward with the
gradient all-reduce. The trace confirms that backward and all-reduce do
overlap, and that the per-parameter gradients are indeed copied into a flat
buffer, so these options are effective for ZeRO 1 as well.


![image](https://github.com/microsoft/DeepSpeed/assets/39846316/5f876296-e1b4-404b-8b33-03cee8e5e6b2)


![image](https://github.com/microsoft/DeepSpeed/assets/39846316/9654f6be-5c7a-401a-b0bc-413ecd3f4e6b)

Related issue: #2295

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
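
For reference, a hedged sketch of a DeepSpeed config that turns both options on under ZeRO stage 1; the batch-size and fp16 values are illustrative assumptions, not settings taken from the PR.

```python
# Illustrative ZeRO stage-1 config; only the zero_optimization flags matter here,
# the remaining values are placeholder assumptions.
ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 1,
        "overlap_comm": True,          # start gradient all-reduce during backward
        "contiguous_gradients": True,  # copy gradients into a flat buffer
    },
}

# Typical usage (model definition omitted):
# import deepspeed
# engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, model_parameters=model.parameters(), config=ds_config)
```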
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request on Feb 17, 2024