@@ -418,6 +418,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
418418            (applicable to 2D sharding only) 
419419            if set and DMP collection is enabled for 2D sharding, 
420420            sync DMPs every N batches (default to 1, i.e. every batch, None to disable) 
421+         gradient_accumulation_steps (int): number of steps to accumulate gradients before 
422+             performing backward pass and optimizer update. Default is 1 (no accumulation). 
423+         should_scale_losses (bool): whether to scale accumulated losses by 
424+             gradient_accumulation_steps. Default is False. 
421425    """ 
422426
423427    # The PipelinedForward class that is used in _rewrite_model 
@@ -438,6 +442,8 @@ def __init__(
438442        ] =  None ,
439443        dmp_collection_sync_interval_batches : Optional [int ] =  1 ,
440444        enqueue_batch_after_forward : bool  =  False ,
445+         gradient_accumulation_steps : int  =  1 ,
446+         should_scale_losses : bool  =  False ,
441447    ) ->  None :
442448        self ._model  =  model 
443449        self ._optimizer  =  optimizer 
@@ -503,6 +509,11 @@ def __init__(
503509            dmp_collection_sync_interval_batches 
504510        )
505511
512+         self ._accumulation_steps : int  =  gradient_accumulation_steps 
513+         self ._accumulation_step_count : int  =  gradient_accumulation_steps  -  1 
514+         self ._should_scale_losses : bool  =  should_scale_losses 
515+         self ._is_first_step : bool  =  True 
516+ 
506517        if  self ._dmp_collection_sync_interval_batches  is  not None :
507518            logger .info (
508519                f"{ self .__class__ .__name__ }  
@@ -680,7 +691,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
680691        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only) 
681692        self ._set_module_context (self .contexts [0 ])
682693
683-         if  self ._model .training :
694+         # only zero grad at the start of each accumulation 
695+         if  self ._model .training  and  (
696+             self ._is_first_step  or  self ._accumulation_step_count  ==  0 
697+         ):
684698            with  record_function ("## zero_grad ##" ):
685699                self ._optimizer .zero_grad ()
686700
@@ -696,35 +710,57 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
696710            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here 
697711            self .enqueue_batch (dataloader_iter )
698712
699-         # forward 
700-         with  record_function (f"## forward { self .contexts [0 ].index } ##" ):
701-             self ._state  =  PipelineState .CALL_FWD 
702-             losses , output  =  self ._model_fwd (self .batches [0 ])
713+         # NOTE: the first step cannot be no_sync when DDP.static_graph = True, 
714+         #       due to an unfortunate restriction in torch.distributed 
715+         no_sync  =  not  self ._is_first_step  and  (
716+             self ._model .training 
717+             and  self ._accumulation_step_count  +  1  <  self ._accumulation_steps 
718+         )
719+         with  (
720+             self ._model ._dmp_wrapped_module .no_sync ()  # pyre-ignore[16] 
721+             if  no_sync 
722+             else  contextlib .nullcontext ()
723+         ):
724+             # forward 
725+             with  record_function (f"## forward { self .contexts [0 ].index } ##" ):
726+                 self ._state  =  PipelineState .CALL_FWD 
727+                 losses , output  =  self ._model_fwd (self .batches [0 ])
703728
704-         if  self ._enqueue_batch_after_forward :
705-             # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here. 
706-             # Start this step after the forward of batch i, so that the H2D copy doesn't compete 
707-             # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING. 
708-             self .enqueue_batch (dataloader_iter )
729+              if  self ._enqueue_batch_after_forward :
730+                  # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here. 
731+                  # Start this step after the forward of batch i, so that the H2D copy doesn't compete 
732+                  # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING. 
733+                  self .enqueue_batch (dataloader_iter )
709734
710-         if  len (self .batches ) >=  2 :
711-             # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist) 
712-             self .wait_sparse_data_dist (self .contexts [1 ])
735+              if  len (self .batches ) >=  2 :
736+                  # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist) 
737+                  self .wait_sparse_data_dist (self .contexts [1 ])
713738
714-         if  self ._model .training :
715739            # backward 
716-             self ._state  =  PipelineState .CALL_BWD 
717-             self ._backward (losses )
718- 
719-             self .sync_embeddings (
720-                 self ._model ,
721-                 self ._dmp_collection_sync_interval_batches ,
722-                 self .contexts [0 ],
723-             )
724- 
725-             # update 
726-             with  record_function (f"## optimizer { self .contexts [0 ].index } ##" ):
727-                 self ._optimizer .step ()
740+             if  self ._model .training :
741+                 self ._state  =  PipelineState .CALL_BWD 
742+                 if  (
743+                     self ._should_scale_losses 
744+                     and  self ._accumulation_steps  >  1 
745+                     and  not  self ._is_first_step 
746+                 ):
747+                     losses  =  losses  /  self ._accumulation_steps 
748+                 self ._backward (losses )
749+ 
750+             if  no_sync :
751+                 self ._accumulation_step_count  +=  1 
752+             else :
753+                 self .sync_embeddings (
754+                     self ._model ,
755+                     self ._dmp_collection_sync_interval_batches ,
756+                     self .contexts [0 ],
757+                 )
758+                 # update 
759+                 with  record_function (f"## optimizer { self .contexts [0 ].index } ##" ):
760+                     self ._optimizer .step ()
761+                 self ._accumulation_step_count  =  0 
762+                 if  self ._is_first_step :
763+                     self ._is_first_step  =  False 
728764
729765        self .dequeue_batch ()
730766        return  output 
0 commit comments