
enable merging parameters for diloco #212

Draft · wants to merge 4 commits into base: main
79 changes: 72 additions & 7 deletions torchft/local_sgd.py
@@ -213,8 +213,14 @@ def __init__(
self.should_quantize = should_quantize

self._grads: Dict[str, torch.Tensor] = {}

# Used to save global parameters so that they can be restored in case
# commit fails
self.original_parameters: Dict[str, torch.Tensor] = {}

# Used to mix the local and global parameters
self._local_parameters: Dict[str, torch.Tensor] = {}

for name, p in self._model_fragment.named_parameters():
if isinstance(p, DTensor):
p = extract_local_tensor(p.data)
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
param_to_local = extract_local_tensor(p.data)
self.original_parameters[name].copy_(param_to_local, non_blocking=True)

def _save_local_parameters(self) -> None:
"""
Saves a copy of the model's local parameters.
"""
with torch.no_grad():
for name, p in self._model_fragment.named_parameters():
self._local_parameters[name] = extract_local_tensor(p.data)

@torch.profiler.record_function("torchft::local_sgd::restore_parameters")
def restore_parameters(self) -> None:
with torch.no_grad():
@@ -293,6 +307,21 @@ def _set_grads(self) -> None:
# No longer needed
del self._grads[name]

def _clear_local_parameters(self) -> None:
"""
Clears the saved copy of the model's local parameters.
"""
self._local_parameters = {}

def _merge_parameters(self) -> None:
"""
Merges the local and global parameters.
"""
for name, p in self._model_fragment.named_parameters():
# In-place lerp: keep `fragment_update_alpha` of the global parameters
# and mix in `1 - fragment_update_alpha` of the local parameters.
p.data.lerp_(
self._local_parameters[name], 1 - self._fragment_update_alpha
)

@torch.profiler.record_function("torchft::local_sgd::wait")
def wait(self) -> None:
"""
@@ -357,22 +386,54 @@ def perform_sync(self) -> bool:
steps using the outer optimizer.
"""
if len(self._allreduce_futures) == 0:
return True
assert self._fragment_sync_delay > 0
# This can happen when using `fragment_sync_delay`. The node
# might not have participated in the sync of this fragment.
#
# The allreduce for the nodes that did participate might still
# succeed, and in that case we shouldn't allow recovery
# from this node.
#
# We do need to increase the `max_step` here so we
# don't end up in an infinite loop of needing to recover.
#
# TODO: We can add an `is_catching_up` flag to the state_dict
# to disallow recoveries from this node. Such nodes can
# be excluded from the `max_step` calculation unless all
# nodes are catching up.
return self._manager.should_commit()

self.wait()

# Save the local parameters so they can be used for merging
self._save_local_parameters()
# Restore the parameters to the last globally synced state
self.restore_parameters()

# This can return success even if the allreduce failed, because
# the process group could have been reconfigured while the
# allreduce was in flight. The in-flight allreduce may or may
# not have been aborted.
#
# We consider it successful anyway.
#
# TODO: We could track errors per allreduce and let the commit
# fail here, but that has the downside of reconfiguring the pg
# too many times, resulting in more aborts and more commit
# failures.
should_commit = self._manager.should_commit()

if should_commit:
# Use the outer optimizer to update the model parameters
self._set_grads()
self._outer_optimizer.step()
self.save_parameters()
self._merge_parameters()
self._outer_optimizer.zero_grad()

# free up memory
self._clear_local_parameters()

return should_commit

def _average_grads(self) -> None:
@@ -524,12 +585,6 @@ def __init__(
if fragment_update_alpha < 0 or fragment_update_alpha > 1:
raise ValueError("fragment_update_alpha must be between 0 and 1")

# TODO: Support `fragment_update_alpha`
if fragment_update_alpha != 0.0:
raise ValueError(
"Merging local parameters with global parameters is not supported yet"
)

super().__init__()
self._manager = manager

@@ -708,6 +763,16 @@ def _step_post_hook(
# waste after recovery
self._quorum_loop()

# TODO: Since we do quorum after commit, there might be a big gap until
# the next allreduce. This increases the chances of nodes failing and,
# with them, of the allreduce failing.
# - We could do a quorum again right before preparing for a fragment
#   using `shrink_only`. This might make it tricky for new nodes to
#   join though.
# - Maintain a sequence number in the state dict that gets bumped at every
#   quorum call. Then we can do a quorum right before the allreduce and
#   avoid doing quorums after commit.

# We need to make sure `_local_step` is still
# the same across all replicas if `quorum_id` changed.
#
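Taken together, the new `perform_sync` path saves the local weights, restores the last synced global weights, applies the outer optimizer step, and then mixes the local weights back in. Below is a minimal sketch of that flow, assuming plain parameters (no DTensor handling) and a hypothetical `outer_optimizer` whose gradients are already the averaged pseudo-gradients; it illustrates the order of operations rather than the torchft implementation. Note that the global copy is saved before merging, which is why the integration test below compares `original_params` instead of the live model.

import torch
import torch.nn as nn

def sketch_sync(
    model: nn.Module,
    outer_optimizer: torch.optim.Optimizer,
    original_parameters: dict[str, torch.Tensor],
    fragment_update_alpha: float,
) -> None:
    with torch.no_grad():
        # 1. Save the local weights before touching them.
        local_parameters = {
            name: p.detach().clone() for name, p in model.named_parameters()
        }
        # 2. Restore the last globally synced weights so the outer optimizer
        #    steps from the same starting point on every replica.
        for name, p in model.named_parameters():
            p.copy_(original_parameters[name])
    # 3. Outer optimizer step over the restored global weights.
    outer_optimizer.step()
    with torch.no_grad():
        for name, p in model.named_parameters():
            # 4. Save the new global weights for recovery/restore ...
            original_parameters[name].copy_(p)
            # 5. ... then merge: p = alpha * global + (1 - alpha) * local
            p.lerp_(local_parameters[name], 1 - fragment_update_alpha)
    outer_optimizer.zero_grad()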
23 changes: 18 additions & 5 deletions torchft/local_sgd_integ_test.py
@@ -414,13 +414,26 @@ def test_diloco_recovery(self, use_cuda: bool) -> None:
rep0, rep1 = state_dicts

for step in rep0.keys():
# Inner optimizer will be different, outer optimizer and model should be the same
# The inner optimizer and the local model parameters will be different.
# For example, with 2 replicas r1 and r2 syncing every 2 steps:
#
# - Manager Step 1
#   - Step 1: r1 and r2 step
#   - Step 2: r1 and r2 step, sync the model, quorum succeeds
# - Manager Step 2
#   - Step 1: r1 steps but r2 fails
#   - Step 2:
#     - r1 steps, sync fails because r2 is down
#     - r1 recovers r2 from its model state at this step, which is
#       different from r1's model at the beginning of Manager Step 2
#
# The outer optimizer and the global model should be the same.

torch.testing.assert_close(
rep1[step]["model"],
rep0[step]["model"],
rep1[step]["original_params"],
rep0[step]["original_params"],
check_device=False,
rtol=1e-4,
atol=1e-4,
)
torch.testing.assert_close(
rep1[step]["outer_optim"],
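For reference, the interpolation that makes the local models diverge while `original_params` stays identical behaves as follows; the tensors and `alpha` value here are made up for illustration and are not torchft defaults.

import torch

alpha = 0.5                            # example fragment_update_alpha
global_p = torch.tensor([1.0, 1.0])    # parameters after the outer optimizer step
local_p = torch.tensor([3.0, 5.0])     # parameters this replica computed locally

# lerp(a, b, w) = a + w * (b - a) = (1 - w) * a + w * b, so with w = 1 - alpha
# the merged value keeps `alpha` of the global parameters:
merged = global_p.lerp(local_p, 1 - alpha)
assert torch.equal(merged, torch.tensor([2.0, 3.0]))  # 0.5 * global + 0.5 * local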
21 changes: 16 additions & 5 deletions torchft/manager.py
@@ -219,6 +219,9 @@ def __init__(
torch.cuda.Stream() if torch.cuda.is_available() else None
)

# Used to synchronize the recovery operation
self._recovery_event: Optional[torch.cuda.Event] = None

if self._group_rank == 0:
if port is None:
port = int(os.environ.get(MANAGER_PORT_ENV, 0))
@@ -323,6 +326,7 @@ def allreduce(
return fut

self.wait_quorum()
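# Capture the participant count now so the callback below divides by the
# value in effect for this allreduce, even if membership changes later.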
num_participants: int = self.num_participants()

if not self.is_participating():
tensor.zero_()
@@ -337,6 +341,7 @@
)
else:
work = self._pg.allreduce([tensor], ReduceOp.SUM)
work.wait()
fut = work.get_future()

stream: Optional[torch.cuda.Stream] = (
@@ -349,13 +354,13 @@
def callback(
fut: torch.futures.Future[List[torch.Tensor]],
) -> torch.Tensor:
nonlocal tensor, stream
nonlocal tensor, stream, num_participants

# change the stream to avoid making the callback stream
# dependent on process group stream running the allreduce
with torch.cuda.stream(stream) if stream is not None else nullcontext():
fut.value()
tensor /= self.num_participants()
tensor /= num_participants

return tensor
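Capturing `num_participants` on the main path means the callback divides by the participant count that was in effect when the allreduce was issued, even if membership changes before the future completes. A stripped-down sketch of the same pattern with a plain `torch.futures.Future`, independent of the torchft process group:

import torch

def average_when_done(
    work_future: torch.futures.Future,
    tensor: torch.Tensor,
    num_participants: int,
) -> torch.futures.Future:
    # `num_participants` is captured now; the callback uses this value even
    # if the group is reconfigured before the allreduce finishes.
    def callback(fut: torch.futures.Future) -> torch.Tensor:
        fut.value()                 # re-raise any error from the allreduce
        tensor /= num_participants  # turn the summed tensor into an average
        return tensor

    return work_future.then(callback)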

@@ -644,7 +649,12 @@ def _async_quorum(
except Exception as e:
self._logger.exception(f"got exception in recovery: {e}")
self.report_error(e)
return

self._recovery_event = (
torch.cuda.current_stream().record_event()
if recovery_stream is not None
else None
)

def _apply_pending_state_dict(self) -> None:
assert self._healing, "must be in healing state"
@@ -704,8 +714,9 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
with torch.profiler.record_function(
"torchft::manager::should_commmit::recovery_stream::synchronize"
):
if self._recovery_stream is not None:
self._recovery_stream.synchronize()
if self._recovery_event is not None:
self._recovery_event.synchronize()
self._recovery_event = None

with torch.profiler.record_function(
"torchft::manager::should_commit::current_stream::synchronize"
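The `_recovery_event` change narrows the wait in `should_commit` from synchronizing the whole recovery stream to synchronizing a single event recorded after the recovery work is enqueued. In isolation, and assuming a CUDA device is available, the pattern looks roughly like this sketch:

import torch

recovery_stream = torch.cuda.Stream()

with torch.cuda.stream(recovery_stream):
    # ... enqueue recovery work (e.g. copying a recovered state_dict) ...
    # Record an event that marks the end of the enqueued recovery work.
    recovery_event = torch.cuda.current_stream().record_event()

# Later, before committing: wait only for the recorded work to finish
# instead of synchronizing the entire stream.
if recovery_event is not None:
    recovery_event.synchronize()
    recovery_event = None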