
Commit 377e8cf

Zheng-Yong (Arsa) Ang authored and facebook-github-bot committed
enable gradient accumulation in SDD (#3462)
Summary:
Context: gradient accumulation is still not available in TorchRec, in particular for the SDD pipeline, which is used by many recommendation models.
This diff implements a UI enabling gradient accumulation for SDD.

Differential Revision: D84915986
1 parent bfcbd1e commit 377e8cf

File tree

1 file changed (+62, -26 lines)


torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 62 additions & 26 deletions
@@ -418,6 +418,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only)
             if set and DMP collection is enabled for 2D sharding,
             sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
+        gradient_accumulation_steps (int): number of steps to accumulate gradients before
+            performing backward pass and optimizer update. Default is 1 (no accumulation).
+        should_scale_losses (bool): whether to scale accumulated losses by
+            gradient_accumulation_steps. Default is False.
     """

     # The PipelinedForward class that is used in _rewrite_model
@@ -438,6 +442,8 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        gradient_accumulation_steps: int = 1,
+        should_scale_losses: bool = False,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
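As a hedged usage sketch (not part of this commit), the two new arguments would be passed alongside the existing constructor parameters roughly as follows; build_model_and_optimizer is a placeholder for however the application constructs its DistributedModelParallel model and optimizer:

# Illustrative only: placeholder model/optimizer construction; the two new
# keyword arguments are the ones added by this diff.
import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

model, optimizer = build_model_and_optimizer()  # placeholder, application-specific

pipeline = TrainPipelineSparseDist(
    model,
    optimizer,
    torch.device("cuda"),
    gradient_accumulation_steps=4,  # backward on every batch, optimizer.step() every 4th batch
    should_scale_losses=True,       # divide each loss by 4 so gradients match a single 4x batch
)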
@@ -503,6 +509,11 @@ def __init__(
             dmp_collection_sync_interval_batches
         )

+        self._accumulation_steps: int = gradient_accumulation_steps
+        self._accumulation_step_count: int = gradient_accumulation_steps - 1
+        self._should_scale_losses: bool = should_scale_losses
+        self._is_first_step: bool = True
+
         if self._dmp_collection_sync_interval_batches is not None:
             logger.info(
                 f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +691,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
         self._set_module_context(self.contexts[0])

-        if self._model.training:
+        # only zero grad at the start of each accumulation
+        if self._model.training and (
+            self._is_first_step or self._accumulation_step_count == 0
+        ):
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()

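For comparison, the same gating appears in a plain, non-pipelined PyTorch loop: gradients are zeroed only at the start of each accumulation window. This is a generic illustration, not code from this diff:

# Generic illustration of zeroing gradients only at the start of each
# accumulation window, as the gated zero_grad above does.
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for step, batch in enumerate(torch.randn(16, 8).split(2)):
    if step % accumulation_steps == 0:
        optimizer.zero_grad()              # start of a new accumulation window
    loss = model(batch).sum() / accumulation_steps
    loss.backward()                        # gradients accumulate in .grad
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                   # one update per window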
@@ -696,35 +710,57 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
         self.enqueue_batch(dataloader_iter)

-        # forward
-        with record_function(f"## forward {self.contexts[0].index} ##"):
-            self._state = PipelineState.CALL_FWD
-            losses, output = self._model_fwd(self.batches[0])
+        # NOTE: the first step cannot be no_sync when DDP.static_graph = True,
+        # due to an unfortunate restriction in torch.distributed
+        no_sync = not self._is_first_step and (
+            self._model.training
+            and self._accumulation_step_count + 1 < self._accumulation_steps
+        )
+        with (
+            self._model._dmp_wrapped_module.no_sync()  # pyre-ignore[16]
+            if no_sync
+            else contextlib.nullcontext()
+        ):
+            # forward
+            with record_function(f"## forward {self.contexts[0].index} ##"):
+                self._state = PipelineState.CALL_FWD
+                losses, output = self._model_fwd(self.batches[0])

-        if self._enqueue_batch_after_forward:
-            # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
-            # Start this step after the forward of batch i, so that the H2D copy doesn't compete
-            # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
-            self.enqueue_batch(dataloader_iter)
+            if self._enqueue_batch_after_forward:
+                # batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
+                # Start this step after the forward of batch i, so that the H2D copy doesn't compete
+                # for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
+                self.enqueue_batch(dataloader_iter)

-        if len(self.batches) >= 2:
-            # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
-            self.wait_sparse_data_dist(self.contexts[1])
+            if len(self.batches) >= 2:
+                # invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
+                self.wait_sparse_data_dist(self.contexts[1])

-        if self._model.training:
             # backward
-            self._state = PipelineState.CALL_BWD
-            self._backward(losses)
-
-            self.sync_embeddings(
-                self._model,
-                self._dmp_collection_sync_interval_batches,
-                self.contexts[0],
-            )
-
-            # update
-            with record_function(f"## optimizer {self.contexts[0].index} ##"):
-                self._optimizer.step()
+            if self._model.training:
+                self._state = PipelineState.CALL_BWD
+                if (
+                    self._should_scale_losses
+                    and self._accumulation_steps > 1
+                    and not self._is_first_step
+                ):
+                    losses = losses / self._accumulation_steps
+                self._backward(losses)
+
+                if no_sync:
+                    self._accumulation_step_count += 1
+                else:
+                    self.sync_embeddings(
+                        self._model,
+                        self._dmp_collection_sync_interval_batches,
+                        self.contexts[0],
+                    )
+                    # update
+                    with record_function(f"## optimizer {self.contexts[0].index} ##"):
+                        self._optimizer.step()
+                    self._accumulation_step_count = 0
+            if self._is_first_step:
+                self._is_first_step = False

         self.dequeue_batch()
         return output
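For readers less familiar with the no_sync pattern the new progress() logic builds on, the sketch below shows the same technique with plain DistributedDataParallel: skip the gradient all-reduce on every micro-batch except the last, optionally scale each loss by 1/N, and step the optimizer once per window. It illustrates the pattern only and assumes torch.distributed has already been initialized (e.g. via torchrun):

# Minimal sketch of no_sync-based gradient accumulation with plain DDP,
# mirroring the pattern used in the pipeline above.
import contextlib

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def train_window(
    ddp_model: DDP,
    optimizer: torch.optim.Optimizer,
    micro_batches: list[torch.Tensor],
    scale_losses: bool = True,
) -> None:
    n = len(micro_batches)
    optimizer.zero_grad()
    for i, batch in enumerate(micro_batches):
        last = i == n - 1
        # Skip the gradient all-reduce for every micro-batch except the last one.
        ctx = contextlib.nullcontext() if last else ddp_model.no_sync()
        with ctx:
            loss = ddp_model(batch).sum()
            if scale_losses:
                loss = loss / n  # keep the accumulated gradient on the scale of one big batch
            loss.backward()
    optimizer.step()  # one parameter update per accumulation window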
