Skip to content

Commit 0be173f

Browse files
committed
Add option to minimise all-reduces in HSDP
Signed-off-by: Nathan Azrak <nathan.azrak@gmail.com>
1 parent c97f72c commit 0be173f

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

recipes/full_finetune_distributed.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,15 @@ def __init__(self, cfg: DictConfig) -> None:
189189
self._resume_from_checkpoint = cfg.resume_from_checkpoint
190190
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
191191
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
192+
193+
# Should we raise an error rather than performing these silent checks like with `_optimizer_in_bwd` and `_clip_grad_norm`?
194+
self._minimize_all_reduces = (
195+
cfg.get("minimize_all_reduces", False)
196+
and self._gradient_accumulation_steps > 1
197+
and not self._optimizer_in_bwd
198+
and self.parallel_dims.dp_enabled
199+
)
200+
192201
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
193202
self._checkpoint_client = CheckpointClient(cfg)
194203

@@ -834,6 +843,12 @@ def train(self) -> None:
834843
# We multiply by world_size to undo FSDP2 gradient normalization.
835844
current_loss = current_loss * (self.world_size / num_tokens)
836845

846+
if self._minimize_all_reduces and (
847+
(idx + 1) % self._gradient_accumulation_steps == 0
848+
):
849+
self._model.set_is_last_backward(True)
850+
self._model.set_requires_all_reduce(True)
851+
837852
current_loss.backward()
838853

839854
# Step with optimizer
@@ -857,6 +872,10 @@ def train(self) -> None:
857872
self._optimizer.step()
858873
self._optimizer.zero_grad(set_to_none=True)
859874

875+
if self._minimize_all_reduces:
876+
self._model.set_is_last_backward(False)
877+
self._model.set_requires_all_reduce(False)
878+
860879
# Update the number of steps when the weights are updated
861880
self.global_step += 1
862881

0 commit comments

Comments (0)