Commit 2a10ef2

enable merging parameters for diloco (#212)
Summary:

- merge local and global parameters of the model after synchronization
- add the "alpha" parameter to integration tests

Test Plan:

```
pytest -vs ./torchft/local_sgd_integ_test.py
```
1 parent 1682257 commit 2a10ef2
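
For background: the merge this commit enables is a convex combination of the globally synchronized parameters and the worker's pre-sync local parameters, weighted by `fragment_update_alpha`. Below is a minimal sketch of the arithmetic only, with hypothetical tensors standing in for model fragments (this is not torchft API):

```python
import torch

# Minimal sketch of the merge arithmetic, not torchft API: `global_param`
# stands in for a parameter after the outer optimizer step, `local_param`
# for the copy saved before synchronization, `alpha` for fragment_update_alpha.
alpha = 0.5
global_param = torch.tensor([1.0, 2.0, 3.0])
local_param = torch.tensor([3.0, 2.0, 1.0])

# lerp_(end, weight) computes self + weight * (end - self) in place, so a
# weight of (1 - alpha) keeps an alpha fraction of the global value.
merged = global_param.clone().lerp_(local_param, 1 - alpha)
assert torch.allclose(merged, alpha * global_param + (1 - alpha) * local_param)
```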

File tree

2 files changed: +40 −10 lines


torchft/local_sgd.py

Lines changed: 33 additions & 6 deletions
```diff
@@ -213,8 +213,14 @@ def __init__(
         self.should_quantize = should_quantize
 
         self._grads: Dict[str, torch.Tensor] = {}
+
+        # Used to save global parameters so that they can be restored in case
+        # commit fails
         self.original_parameters: Dict[str, torch.Tensor] = {}
 
+        # Used to mix the local and global parameters
+        self._local_parameters: Dict[str, torch.Tensor] = {}
+
         for name, p in self._model_fragment.named_parameters():
             if isinstance(p, DTensor):
                 p = extract_local_tensor(p.data)
```
```diff
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
                 param_to_local = extract_local_tensor(p.data)
                 self.original_parameters[name].copy_(param_to_local, non_blocking=True)
 
+    def _save_local_parameters(self) -> None:
+        """
+        Saves a copy of the model's parameters.
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                self._local_parameters[name] = extract_local_tensor(p.data)
+
     @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
```
```diff
@@ -293,6 +307,19 @@ def _set_grads(self) -> None:
                 # No longer needed
                 del self._grads[name]
 
+    def _clear_local_parameters(self) -> None:
+        """
+        Clears the saved copy of the model's parameters
+        """
+        self._local_parameters = {}
+
+    def _merge_parameters(self) -> None:
+        """
+        Merges the local and global parameters.
+        """
+        for name, p in self._model_fragment.named_parameters():
+            p.data.lerp_(self._local_parameters[name], 1 - self._fragment_update_alpha)
+
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
```
```diff
@@ -382,6 +409,8 @@ def perform_sync(self) -> bool:
 
         self.wait()
 
+        # save the parameters so they can be used for merging
+        self._save_local_parameters()
         # Restore the parameters back to the previous state
         self.restore_parameters()
 
```
```diff
@@ -404,8 +433,12 @@ def perform_sync(self) -> bool:
             self._set_grads()
             self._outer_optimizer.step()
             self.save_parameters()
+            self._merge_parameters()
         self._outer_optimizer.zero_grad()
 
+        # free up memory
+        self._clear_local_parameters()
+
         return should_commit
 
     def _average_grads(self) -> None:
```
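
Taken together with the `@@ -382,6 +409,8 @@` hunk, the call order around a sync is: wait, save the local weights, restore the last global snapshot, run the outer step on a successful commit, save the new snapshot, merge, then clear. A usage sketch of that ordering, reusing the hypothetical MergeSketch class from the earlier sketch (commit and allreduce machinery elided):

```python
import torch.nn as nn

# Usage sketch of the ordering perform_sync applies the new helpers in.
model = nn.Linear(4, 2)
merger = MergeSketch(model, alpha=0.5)

merger.save_local_parameters()   # right after wait(), before restore_parameters()
# ... restore_parameters(), the should_commit decision, _set_grads(),
# _outer_optimizer.step(), and save_parameters() would run here ...
merger.merge_parameters()        # immediately after save_parameters()
merger.clear_local_parameters()  # "free up memory" at the end of the sync
```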
```diff
@@ -557,12 +590,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")
 
-        # TODO: Support `fragment_update_alpha`
-        if fragment_update_alpha != 0.0:
-            raise ValueError(
-                "Merging local parameters with global parameters is not supported yet"
-            )
-
         super().__init__()
         self._manager = manager
 
```
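
With the TODO guard gone, the only validation left is the range check that survives the hunk. A standalone restatement, under the assumption that nothing else constrains the value:

```python
# Restatement of the surviving range check: after this commit any alpha in
# [0, 1] is accepted, rather than only 0.0.
def validate_fragment_update_alpha(fragment_update_alpha: float) -> float:
    if fragment_update_alpha < 0 or fragment_update_alpha > 1:
        raise ValueError("fragment_update_alpha must be between 0 and 1")
    return fragment_update_alpha


for alpha in (0.0, 0.5, 1.0):
    validate_fragment_update_alpha(alpha)  # all pass after this commit
```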
torchft/local_sgd_integ_test.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -589,18 +589,19 @@ def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
 
         self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
 
-    CONFIG: list[tuple[bool, int, int]] = [
-        (use_cuda, n_fragments, fragment_sync_delay)
+    CONFIG: list[tuple[bool, int, int, float]] = [
+        (use_cuda, n_fragments, fragment_sync_delay, alpha)
         for use_cuda in [False]
         for n_fragments in [1, 2]
         for fragment_sync_delay in [0, 1]
+        for alpha in [0.0, 0.5, 1.0]
     ]
 
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @skipIf(sys.platform == "darwin", "not reliable on mac")
     @parameterized.expand(CONFIG)
     def test_streaming_diloco_upscale(
-        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int
+        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int, alpha: float
     ) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
```
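
The extra `for alpha in [0.0, 0.5, 1.0]` clause triples the parameter grid: the comprehension now expands to 1 x 2 x 2 x 3 = 12 tuples per test. An equivalent formulation with itertools.product makes the count explicit:

```python
from itertools import product

# Equivalent expansion of the CONFIG comprehension above:
# 1 use_cuda x 2 n_fragments x 2 delays x 3 alphas = 12 parameterizations.
CONFIG = list(product([False], [1, 2], [0, 1], [0.0, 0.5, 1.0]))
assert len(CONFIG) == 12
```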
```diff
@@ -642,6 +643,7 @@ def test_streaming_diloco_upscale(
                 "diloco_args": {
                     "fragment_sync_delay": fragment_sync_delay,
                     "sync_every": 4,
+                    "fragment_update_alpha": alpha,
                 },
             },
         )
```
```diff
@@ -681,7 +683,7 @@ def test_streaming_diloco_upscale(
     @skipIf(sys.platform == "darwin", "not reliable on mac")
     @parameterized.expand(CONFIG)
     def test_streaming_diloco_commit_failure(
-        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int
+        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int, alpha: float
     ) -> None:
         # Skip the test if use_cuda is True and there are not enough GPUs
         if use_cuda and torch.cuda.device_count() < 2:
```
```diff
@@ -719,6 +721,7 @@ def test_streaming_diloco_commit_failure(
                 "diloco_args": {
                     "fragment_sync_delay": fragment_sync_delay,
                     "sync_every": 4,
+                    "fragment_update_alpha": alpha,
                 },
             },
         )
```
