@@ -189,6 +189,15 @@ def __init__(self, cfg: DictConfig) -> None:
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+
+        # Should we raise an error rather than performing these silent checks, as with `_optimizer_in_bwd` and `_clip_grad_norm`?
+        self._minimize_all_reduces = (
+            cfg.get("minimize_all_reduces", False)
+            and self._gradient_accumulation_steps > 1
+            and not self._optimizer_in_bwd
+            and self.parallel_dims.dp_enabled
+        )
+
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._checkpoint_client = CheckpointClient(cfg)
 
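For reference, `_minimize_all_reduces` only activates when all four conditions hold; otherwise it is silently dropped. A minimal standalone check of that gating, using omegaconf (the recipe's `DictConfig` type) with a plain boolean standing in for `self.parallel_dims.dp_enabled`:

```python
from omegaconf import OmegaConf

# Stand-ins for recipe state; `dp_enabled` mimics self.parallel_dims.dp_enabled.
cfg = OmegaConf.create(
    {
        "minimize_all_reduces": True,
        "gradient_accumulation_steps": 4,
        "optimizer_in_bwd": False,
    }
)
dp_enabled = True

# Same gating as the diff: the flag only takes effect with gradient
# accumulation on, the optimizer outside backward, and data parallelism enabled.
minimize_all_reduces = (
    cfg.get("minimize_all_reduces", False)
    and cfg.gradient_accumulation_steps > 1
    and not cfg.get("optimizer_in_bwd", False)
    and dp_enabled
)
print(minimize_all_reduces)  # True; flip any one condition and this prints False
```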
@@ -834,6 +843,12 @@ def train(self) -> None:
                 # We multiply by world_size to undo FSDP2 gradient normalization.
                 current_loss = current_loss * (self.world_size / num_tokens)
 
+                if self._minimize_all_reduces and (
+                    (idx + 1) % self._gradient_accumulation_steps == 0
+                ):
+                    self._model.set_is_last_backward(True)
+                    self._model.set_requires_all_reduce(True)
+
                 current_loss.backward()
 
                 # Step with optimizer
@@ -857,6 +872,10 @@ def train(self) -> None:
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
 
+                    if self._minimize_all_reduces:
+                        self._model.set_is_last_backward(False)
+                        self._model.set_requires_all_reduce(False)
+
                     # Update the number of steps when the weights are updated
                     self.global_step += 1
 
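Taken together, the hunks arm the FSDP2 flags just before the backward that closes an accumulation window and disarm them right after the optimizer step, so intermediate microbatches skip the cross-rank gradient reduction. Below is a simplified sketch of the same pattern, not the recipe itself: it toggles the flags on every microbatch instead of the arm/disarm sequencing above, the tiny `Linear` model, batch shapes, and step counts are placeholders, and the recipe's two `FSDPModule` setter calls are mirrored verbatim (for plain FSDP, the reduce-scatter is governed by the related `set_requires_gradient_sync`).

```python
"""Sketch of the delayed-sync pattern; run with `torchrun --nproc_per_node=2`.
Assumes a PyTorch build where FSDP2's `fully_shard` is public."""
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(16, 16, device="cuda")
fully_shard(model)  # shards `model` in place; it now exposes the FSDPModule API
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_accum_steps = 4

for idx in range(2 * grad_accum_steps):
    # Only the microbatch that completes an accumulation window triggers
    # the cross-rank reduction and FSDP's post-backward finalization.
    is_last = (idx + 1) % grad_accum_steps == 0
    model.set_is_last_backward(is_last)
    model.set_requires_all_reduce(is_last)

    loss = model(torch.randn(4, 16, device="cuda")).sum()
    loss.backward()

    if is_last:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

dist.destroy_process_group()
```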