From b7161ca9f5e10bd1bd8cd04119c6bf5eff00b306 Mon Sep 17 00:00:00 2001
From: "Zheng-Yong (Arsa) Ang"
Date: Mon, 27 Oct 2025 02:17:40 -0700
Subject: [PATCH] enable gradient accumulation in SDD (#3462)

Summary:
Context: gradient accumulation is still not available in TorchRec, in particular for the
SDD (TrainPipelineSparseDist) pipeline used by many recommendation models.

This diff implements a user-facing interface (a gradient_accumulation_steps constructor
argument) that enables gradient accumulation for SDD.

Differential Revision: D84915986
---
 .../train_pipeline/train_pipelines.py         | 76 +++++++++++++------
 1 file changed, 51 insertions(+), 25 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index b90a27242..44d87c344 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -418,6 +418,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only) if set and DMP collection is enabled
             for 2D sharding, sync DMPs every N batches (default to 1, i.e. every
             batch, None to disable)
+        gradient_accumulation_steps (int): number of steps over which gradients are
+            accumulated before performing the optimizer update. Default is 1 (no accumulation).
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +440,7 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -503,6 +506,10 @@ def __init__(
             dmp_collection_sync_interval_batches
         )
 
+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = gradient_accumulation_steps - 1
+        self._is_first_step: bool = True
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +687,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])
 
-        if self._model.training:
+        # only zero grad at the start of each accumulation
+        if self._model.training and (
+            self._is_first_step or self._accumulation_step_count == 0
+        ):
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
 
@@ -696,35 +706,51 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)
 
-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        # NOTE: the first step cannot be no_sync when DDP.static_graph = True,
+        # due to an unfortunate restriction in torch.distributed
+        no_sync = not self._is_first_step and (
+            self._model.training
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])
 
-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)
 
-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])
 
-        if self._model.training:  # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                self._backward(losses)
 
-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                    self._accumulation_step_count = 0
+                if self._is_first_step:
+                    self._is_first_step = False
 
         self.dequeue_batch()
         return output
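Usage sketch (illustrative, not part of the patch). The snippet below shows how the new
gradient_accumulation_steps argument might be wired into a training loop; model, optimizer,
and dataloader are hypothetical placeholders for a DistributedModelParallel-wrapped model,
its optimizer, and a data source. As written, the pipeline does not rescale the loss, so
gradients are summed over the accumulation window; scale the loss by 1/N in the loss
computation if averaged gradients are desired.

    import torch
    from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

    # model, optimizer, and dataloader are assumed to exist (hypothetical placeholders).
    pipeline = TrainPipelineSparseDist(
        model=model,
        optimizer=optimizer,
        device=torch.device("cuda"),
        gradient_accumulation_steps=4,  # new argument introduced by this patch
    )

    dataloader_iter = iter(dataloader)
    while True:
        try:
            # After an initial warm-up step (the first step cannot use no_sync with
            # DDP static_graph), zero_grad() and optimizer.step() run once every
            # 4 calls to progress(); the intermediate calls skip the DDP gradient
            # all-reduce via no_sync() and only accumulate gradients.
            output = pipeline.progress(dataloader_iter)
        except StopIteration:
            break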