88 changes: 62 additions & 26 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -418,6 +418,10 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
(applicable to 2D sharding only)
if set and DMP collection is enabled for 2D sharding,
sync DMPs every N batches (default to 1, i.e. every batch, None to disable)
gradient_accumulation_steps (int): number of batches over which to accumulate gradients
before syncing them and running the optimizer update. The backward pass still runs
every batch; only the gradient all-reduce and optimizer step are deferred.
Default is 1 (no accumulation).
should_scale_losses (bool): whether to divide each batch's losses by
gradient_accumulation_steps before the backward pass. Default is False.
"""

# The PipelinedForward class that is used in _rewrite_model
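The gradient_accumulation_steps and should_scale_losses arguments documented in the hunk above are exercised below. This is a minimal usage sketch, not part of the PR: it assumes a DistributedModelParallel `model`, its fused `optimizer`, a `device`, a `dataloader_iter`, and a `num_batches` count, with every other constructor argument left at its default and the usual import path.

from torchrec.distributed.train_pipeline import TrainPipelineSparseDist

pipeline = TrainPipelineSparseDist(
    model=model,
    optimizer=optimizer,
    device=device,
    gradient_accumulation_steps=4,   # accumulate gradients over 4 batches
    should_scale_losses=True,        # divide each batch's losses by 4 before backward
)

# progress() still consumes one batch per call; intermediate calls skip the
# gradient all-reduce via no_sync(), and the optimizer only steps on the call
# that closes each accumulation window.
for _ in range(num_batches):
    output = pipeline.progress(dataloader_iter)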
@@ -438,6 +442,8 @@ def __init__(
] = None,
dmp_collection_sync_interval_batches: Optional[int] = 1,
enqueue_batch_after_forward: bool = False,
gradient_accumulation_steps: int = 1,
should_scale_losses: bool = False,
) -> None:
self._model = model
self._optimizer = optimizer
@@ -503,6 +509,11 @@ def __init__(
dmp_collection_sync_interval_batches
)

self._accumulation_steps: int = gradient_accumulation_steps
self._accumulation_step_count: int = gradient_accumulation_steps - 1
self._should_scale_losses: bool = should_scale_losses
self._is_first_step: bool = True

if self._dmp_collection_sync_interval_batches is not None:
logger.info(
f"{self.__class__.__name__}: [Sparse 2D] DMP collection will sync every "
@@ -680,7 +691,10 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
# TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
self._set_module_context(self.contexts[0])

if self._model.training:
# only zero grad at the start of each accumulation window
if self._model.training and (
self._is_first_step or self._accumulation_step_count == 0
):
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()

@@ -696,35 +710,57 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
# batch i+2: load data and copy to gpu, the dataload iter will first exhaust here
self.enqueue_batch(dataloader_iter)

# forward
with record_function(f"## forward {self.contexts[0].index} ##"):
self._state = PipelineState.CALL_FWD
losses, output = self._model_fwd(self.batches[0])
# NOTE: the first step cannot be no_sync when DDP.static_graph = True,
# due to an unfortunate restriction in torch.distributed
no_sync = not self._is_first_step and (
self._model.training
and self._accumulation_step_count + 1 < self._accumulation_steps
)
with (
self._model._dmp_wrapped_module.no_sync() # pyre-ignore[16]
if no_sync
else contextlib.nullcontext()
):
# forward
with record_function(f"## forward {self.contexts[0].index} ##"):
self._state = PipelineState.CALL_FWD
losses, output = self._model_fwd(self.batches[0])

if self._enqueue_batch_after_forward:
# batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
# Start this step after the forward of batch i, so that the H2D copy doesn't compete
# for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
self.enqueue_batch(dataloader_iter)
if self._enqueue_batch_after_forward:
# batch i+2: load data and copy to gpu, the dataload iter will first exhaust here.
# Start this step after the forward of batch i, so that the H2D copy doesn't compete
# for pcie bandwidth with embedding lookup from UVM/UVM_CACHING.
self.enqueue_batch(dataloader_iter)

if len(self.batches) >= 2:
# invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
self.wait_sparse_data_dist(self.contexts[1])
if len(self.batches) >= 2:
# invoke data (values, lengths, etc.) all_to_all comms (second part of input_dist)
self.wait_sparse_data_dist(self.contexts[1])

if self._model.training:
# backward
self._state = PipelineState.CALL_BWD
self._backward(losses)

self.sync_embeddings(
self._model,
self._dmp_collection_sync_interval_batches,
self.contexts[0],
)

# update
with record_function(f"## optimizer {self.contexts[0].index} ##"):
self._optimizer.step()
if self._model.training:
self._state = PipelineState.CALL_BWD
if (
self._should_scale_losses
and self._accumulation_steps > 1
and not self._is_first_step
):
losses = losses / self._accumulation_steps
self._backward(losses)

if no_sync:
self._accumulation_step_count += 1
else:
self.sync_embeddings(
self._model,
self._dmp_collection_sync_interval_batches,
self.contexts[0],
)
# update
with record_function(f"## optimizer {self.contexts[0].index} ##"):
self._optimizer.step()
self._accumulation_step_count = 0
if self._is_first_step:
self._is_first_step = False

self.dequeue_batch()
return output
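For reference, the control flow added to progress() above follows the standard DDP gradient-accumulation pattern: intermediate batches run forward and backward under no_sync() so gradients accumulate locally without an all-reduce, and only the batch that closes an accumulation window syncs gradients, steps the optimizer, and (here) syncs embeddings. Below is a rough standalone sketch of that pattern with a plain DistributedDataParallel model; it is illustrative only, not part of this PR, and all names (accumulating_train_loop, loss_fn, batches) are placeholders.

import contextlib

def accumulating_train_loop(ddp_model, loss_fn, optimizer, batches, accumulation_steps, should_scale_losses=False):
    # `batches` yields (features, labels) pairs; all names here are illustrative.
    optimizer.zero_grad()
    for i, (features, labels) in enumerate(batches):
        # The first batch is a standalone step: with DDP(static_graph=True) the
        # first iteration must not run under no_sync(). Afterwards, every
        # `accumulation_steps`-th batch closes an accumulation window.
        is_sync_step = i == 0 or i % accumulation_steps == 0
        ctx = contextlib.nullcontext() if is_sync_step else ddp_model.no_sync()
        with ctx:
            loss = loss_fn(ddp_model(features), labels)
            if should_scale_losses and accumulation_steps > 1 and i > 0:
                # keep the effective gradient magnitude comparable to a single step
                loss = loss / accumulation_steps
            loss.backward()  # gradients accumulate locally on non-sync steps
        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()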