
Commit c85c870

tohtana, tjruwase, and loadams authored
Fix gradient accumulation for Z2+offload (#6550)
The ZeRO 1/2 optimizer performs incorrect gradient accumulation in the path for ZeRO2 + Offloading. This issue is caused by two main reasons:

1) The micro_step_id in the ZeRO 1/2 optimizer is:
- Initialized to 0 in the constructor.
- Reset to -1 during the backward pass.
For example, given a gradient accumulation step of 4, the micro_step_id changes as follows:
- For the first global step: 1, 2, 3, 4.
- Subsequently: 0, 1, 2, 3.

2) Gradients are copied to the buffer on the first micro step and accumulated in the buffer during the following micro steps. However, the current code incorrectly copies gradients at steps that are not at the accumulation boundary.

This PR aligns the micro_step_id initialization in both the constructor and the backward pass, and corrects the condition for copying and accumulating gradients.

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
1 parent 0fbe96a commit c85c870
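To make the micro_step_id drift described in the commit message concrete, here is a small illustrative Python sketch (not DeepSpeed code); it assumes, as stated above, that backward() increments the counter once per micro step and step() resets it to -1:

# Illustrative sketch of the micro_step_id drift described in the message.
# Assumption (hypothetical helper, not part of DeepSpeed): backward()
# increments the counter once per micro step and step() resets it to -1.
def micro_step_sequence(initial_value, accumulation_steps, global_steps):
    """Return the micro_step_id value seen at each backward micro step."""
    micro_step_id = initial_value
    history = []
    for _ in range(global_steps):
        window = []
        for _ in range(accumulation_steps):
            micro_step_id += 1          # increment done in backward()
            window.append(micro_step_id)
        history.append(window)
        micro_step_id = -1              # reset done in step()
    return history

print(micro_step_sequence(0, 4, 2))     # before the fix: [[1, 2, 3, 4], [0, 1, 2, 3]]
print(micro_step_sequence(-1, 4, 2))    # after the fix:  [[0, 1, 2, 3], [0, 1, 2, 3]]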

File tree

1 file changed (+4, -5 lines)

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 4 additions & 5 deletions
@@ -39,6 +39,7 @@
 OPTIMIZER_GRADIENTS_TIMER = 'optimizer_gradients'
 OPTIMIZER_STEP_TIMER = 'optimizer_step'
 OPTIMIZER_TIMERS = [OPTIMIZER_ALLGATHER_TIMER, OPTIMIZER_GRADIENTS_TIMER, OPTIMIZER_STEP_TIMER]
+INITIAL_MICRO_STEP_ID = -1


 def input(msg):

@@ -224,7 +225,7 @@ def __init__(self,
         self.gradient_predivide_factor = gradient_predivide_factor
         self.postscale_gradients = postscale_gradients
         self.gradient_accumulation_steps = gradient_accumulation_steps
-        self.micro_step_id = 0
+        self.micro_step_id = INITIAL_MICRO_STEP_ID
         self.ignore_unused_parameters = ignore_unused_parameters
         self.round_robin_gradients = round_robin_gradients

@@ -1231,9 +1232,7 @@ def copy_gradients_to_cpu():

         if self.micro_step_id > 0:
             accumulate_gradients()
-
-        # at the boundary we will send 32bit directly
-        if not self.is_gradient_accumulation_boundary:
+        else:
             copy_gradients_to_cpu()

     def set_norm_for_param_grad(self, param):

@@ -1824,7 +1823,7 @@ def step(self, closure=None):
         """
         Not supporting closure.
         """
-        self.micro_step_id = -1
+        self.micro_step_id = INITIAL_MICRO_STEP_ID

         see_memory_usage(f"In step before checking overflow")

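To make the corrected control flow in the diff above concrete, here is a minimal, self-contained sketch. The class, buffer, and method names are hypothetical (this is not DeepSpeed's API); only the micro_step_id-driven copy-vs-accumulate decision mirrors the change in this commit:

# Minimal sketch of the fixed copy-vs-accumulate logic for ZeRO-2 + CPU
# offload. ToyOffloadAccumulator and its members are hypothetical; only the
# micro_step_id handling mirrors the behavior fixed in this commit.
import torch

INITIAL_MICRO_STEP_ID = -1  # constant introduced by this PR


class ToyOffloadAccumulator:

    def __init__(self, gradient_accumulation_steps):
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.micro_step_id = INITIAL_MICRO_STEP_ID
        self.cpu_buffer = None  # stand-in for the offloaded gradient buffer

    def backward_micro_step(self, grad):
        # The counter is bumped once per backward call, so the first micro
        # step of every accumulation window sees micro_step_id == 0.
        self.micro_step_id += 1
        if self.micro_step_id > 0:
            # Later micro steps accumulate into the existing buffer.
            self.cpu_buffer += grad.to("cpu")
        else:
            # The first micro step copies (overwrites) the buffer.
            self.cpu_buffer = grad.to("cpu").clone()

    def step(self):
        # At the accumulation boundary the optimizer would consume the
        # accumulated buffer here, then reset the counter (as step() does
        # in stage_1_and_2.py).
        accumulated = self.cpu_buffer
        self.micro_step_id = INITIAL_MICRO_STEP_ID
        self.cpu_buffer = None
        return accumulated


acc = ToyOffloadAccumulator(gradient_accumulation_steps=4)
for _ in range(4):
    acc.backward_micro_step(torch.ones(3))
print(acc.step())  # tensor([4., 4., 4.]): each micro-step gradient counted exactly once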