@tohtana commented on Sep 18, 2024

DeepSpeed has several ways to call zero_grad(), but they behave inconsistently:

  • ZeRO 1/2 optimizer's zero_grad: Clears .grad and .grad_acc.
  • ZeRO 3 optimizer's zero_grad: Clears .grad and resets micro_step_id. This affects whether gradients are overwritten or accumulated after reduce, and it causes a mismatch with the engine's micro_steps counter (see the sketch after this list).
  • Engine's zero_grad: Clears .grad (it does not call the optimizer's zero_grad inside its own zero_grad), but it does call the optimizer's zero_grad after step().
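
To make the divergence concrete, here is a simplified, illustrative sketch in Python of the three behaviors above. These are stand-in classes, not DeepSpeed's actual implementations; parameter handling is reduced to the minimum needed to show what each zero_grad touches.

# Illustrative stand-ins only, not DeepSpeed's actual classes.

class Zero12OptimizerSketch:
    def __init__(self, params):
        self.params = list(params)

    def zero_grad(self):
        for p in self.params:
            p.grad = None           # clears .grad ...
            p.grad_acc = None       # ... and .grad_acc


class Zero3OptimizerSketch:
    def __init__(self, params):
        self.params = list(params)
        self.micro_step_id = 0

    def zero_grad(self):
        for p in self.params:
            p.grad = None
        self.micro_step_id = 0      # the reset decides overwrite vs. accumulate after reduce
                                    # and can drift from the engine's micro_steps counter


class EngineSketch:
    def __init__(self, module, optimizer):
        self.module = module
        self.optimizer = optimizer
        self.micro_steps = 0

    def zero_grad(self):
        for p in self.module.parameters():
            p.grad = None           # does NOT call self.optimizer.zero_grad() here

    def step(self):
        self.optimizer.step()
        self.optimizer.zero_grad()  # but DOES call it after stepping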

Another source of confusion is that zero_grad does not take the gradient accumulation boundary into account, while backward and step do. Users naturally expect the code below to work, but these inconsistencies can cause unexpected behavior, as noted in the comments.

for batch in data_loader:
    # We need an *if condition* here to run zero_grad only at a gradient accumulation boundary
    target_engine.zero_grad()  # optimizer.zero_grad() is safer, but it behaves differently for ZeRO 1/2 and ZeRO 3

    loss = target_engine(batch)
    target_engine.backward(loss)
    target_engine.step()  # another source of confusion: users can call optimizer.step() instead, but it doesn't work in some cases

This PR aims to improve the behavior of the optimizers.

  • zero_grad clears gradients only at a gradient accumulation boundary (see the usage sketch after this list).
    • Shows a warning (once) when it is called at a step that is not a gradient accumulation boundary
    • Accepts a force kwarg to clear gradients regardless of the boundary
  • The ZeRO 1/2/3 optimizers' zero_grad and the engine's zero_grad have the same effect
    • Users can call either optimizer.zero_grad or engine.zero_grad with any ZeRO stage
    • micro_step_id is no longer reset in the ZeRO 3 optimizer's zero_grad, keeping it consistent with the engine's micro_steps
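
A minimal usage sketch of the intended behavior after this change; target_engine is assumed to be a DeepSpeed engine whose forward pass returns the loss, and force is the keyword introduced by this PR:

for batch in data_loader:
    # With this PR, zero_grad is safe to call on every micro-step: outside a
    # gradient accumulation boundary it only warns (once) and leaves the
    # accumulated gradients intact, so no explicit boundary check is needed.
    target_engine.zero_grad()       # same effect as calling optimizer.zero_grad()

    loss = target_engine(batch)     # assumes the engine's forward returns the loss
    target_engine.backward(loss)
    target_engine.step()

# To clear gradients regardless of the boundary, e.g. when abandoning a
# partially accumulated window, pass the force kwarg:
target_engine.zero_grad(force=True)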

(This PR depends on #6550)
