Commit: Update on "[WIP] Add DiLoCo"

Still WIP but open to feedback on the API

## API Usage

```python
# Imports added for completeness; `SimpleModel` and `dataloader` are
# placeholders, and the `Manager` import path is assumed.
from unittest.mock import create_autospec

import torch
from torch import optim

from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager

# LocalSGD example
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)  # mocked here for brevity
with LocalSGD(manager, model, optimizer, sync_every=2):
    for inp, label in dataloader:
        optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        optimizer.step()

# DiLoCo example
model = SimpleModel()
inner_optimizer = torch.optim.AdamW(
    model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2):
    for inp, label in dataloader:
        inner_optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        inner_optimizer.step()
        # outer_optimizer is actually applied every `sync_every` steps,
        # but this is hidden from the user
```
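
Under the hood, every `sync_every` inner steps the context manager starts a quorum and runs a sync. For DiLoCo that sync roughly does the following — a simplified sketch pieced together from the `_perform_sync` changes in the diff below, not the exact implementation (the `diloco_sync_sketch` name and the direct `manager.allreduce` loop are illustrative):

```python
import torch
from torch import nn, optim


def diloco_sync_sketch(
    manager,
    model: nn.Module,
    outer_optimizer: optim.Optimizer,
    backup_parameters: dict,
) -> None:
    # Assumes LocalSGD.sync() has already called manager.start_quorum().
    with torch.no_grad():
        # Pseudogradient: drift of the inner-optimized weights from the last
        # synchronized (backup) weights.
        for name, p in model.named_parameters():
            p.grad = backup_parameters[name] - p.data

        # Average the pseudogradients across replicas.
        for p in model.parameters():
            manager.allreduce(p.grad)

        # Reset the weights to the last synchronized state before the outer step.
        for name, p in model.named_parameters():
            p.data.copy_(backup_parameters[name])

    if manager.should_commit():
        # Apply the averaged pseudogradient with the outer optimizer and take a
        # new backup of the resulting weights.
        outer_optimizer.step()
        with torch.no_grad():
            for name, p in model.named_parameters():
                backup_parameters[name] = p.data.detach().clone()
    outer_optimizer.zero_grad()
```

In the actual code this logic lives in `DiLoCo._perform_sync`, which reuses the backup/save/restore helpers from `LocalSGD`.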

## Changes
- Updated `LocalSGD` to be a context manager rather than an `nn.Module` wrapper. This required adding a forward pre-hook to the model to start the quorum
- Added `DiLoCo`. It is a subclass of `LocalSGD` since a lot of the code is shared (see the skeleton sketched below)
- TODO: this should be working, but some tests are still being validated
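
A rough skeleton of how the pieces fit together after this change (bodies simplified; the constructor, `__exit__`, and the backup/restore helpers are elided — see `torchft/local_sgd.py` in the diff below):

```python
class LocalSGD:
    """Context manager that synchronizes model weights every `sync_every` steps."""

    def __enter__(self) -> "LocalSGD":
        # sync() is triggered automatically after every optimizer.step().
        self._hooks.append(
            self._local_optimizer.register_step_post_hook(self._step_post_hook)
        )
        return self

    def sync(self) -> None:
        self._manager.start_quorum()
        self._perform_sync()
        self._local_step = 0

    def _perform_sync(self) -> None:
        # LocalSGD: average the weights directly, then commit or roll back.
        self._average()
        if self._manager.should_commit():
            self._save_parameters()
        else:
            self._restore_parameters()


class DiLoCo(LocalSGD):
    def _perform_sync(self) -> None:
        # DiLoCo overrides only the sync step: it averages pseudogradients and
        # applies them with the outer optimizer (see the sketch above).
        ...
```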

discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk

[ghstack-poisoned]
H-Huang committed Jan 29, 2025 · 2 parents 9f8f576 + 3dc8b87 · commit 5fd4430
Showing 3 changed files with 70 additions and 53 deletions.
33 changes: 11 additions & 22 deletions torchft/local_sgd.py
@@ -99,10 +99,6 @@ def __enter__(self) -> "LocalSGD":
self._hooks.append(
self._local_optimizer.register_step_post_hook(self._step_post_hook)
)
# Register a forward prehook to check for quorum
self._hooks.append(
self._model.register_forward_pre_hook(self._forward_step_pre_hook)
)
return self

def __exit__(
@@ -132,7 +128,7 @@ def _restore_parameters(self) -> None:
with torch.no_grad():
# TODO: consider running copy on a separate stream
for name, p in self._model.named_parameters():
p.copy_(self._backup_parameters[name], non_blocking=False)
p.data.copy_(self._backup_parameters[name], non_blocking=False)

def _step_post_hook(
self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
@@ -144,25 +140,12 @@
if self._local_step >= self._sync_every:
self.sync()

def _forward_step_pre_hook(self, _module: nn.Module, _args: List[object]) -> None:
"""
Start the quorum before each module forward.
"""
if self._local_step == 0:
self._manager.start_quorum()

def sync(self) -> None:
"""
Synchronizes and averages the model weights across the manager.
"""
self._manager.start_quorum()
self._perform_sync()

if self._manager.should_commit():
self._save_parameters()
else:
# commit failed, restore from the backup parameters
self._restore_parameters()

self._local_step = 0

def _perform_sync(self) -> None:
Expand All @@ -172,6 +155,11 @@ def _perform_sync(self) -> None:
synchronization logic.
"""
self._average()
if self._manager.should_commit():
self._save_parameters()
else:
# commit failed, restore from the backup parameters
self._restore_parameters()

def _average(self) -> None:
# TODO: do we need to broadcast buffers like DDP does?
@@ -227,12 +215,13 @@ def _perform_sync(self) -> None:
p.grad = pseudogradient

self._average_grads()

# Restore the parameters back to the previous state
self._restore_parameters()

# Use the outer optimizer to update the model parameters
self._outer_optimizer.step()
if self._manager.should_commit():
# Use the outer optimizer to update the model parameters
self._outer_optimizer.step()
self._save_parameters()
self._outer_optimizer.zero_grad()

def _average_grads(self) -> None:
3 changes: 2 additions & 1 deletion torchft/local_sgd_test.py
@@ -100,7 +100,7 @@ def test_local_sgd_recovery(self) -> None:


class DiLoCoTest(TestCase):
def test_diloco_healt(self) -> None:
def test_diloco_healthy(self) -> None:
model = SimpleModel()

# Setup optimizers
Expand All @@ -112,6 +112,7 @@ def test_diloco_healt(self) -> None:
)

manager = create_autospec(Manager)
manager._use_async_quorum = False
with DiLoCo(
manager, model, inner_optimizer, outer_optimizer, sync_every=2
) as diloco:
87 changes: 57 additions & 30 deletions torchft/manager_integ_test.py
@@ -1,17 +1,18 @@
import copy
import logging
import threading
import time
from concurrent.futures import as_completed, ThreadPoolExecutor
from contextlib import contextmanager, ExitStack
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Dict, Generator, List, Optional, Protocol, Set, Tuple
from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, Union
from unittest import TestCase

import torch
import torch.distributed as dist
from parameterized import parameterized
from torch import nn, optim
from torch import Tensor, nn, optim

from torchft.ddp import DistributedDataParallel
from torchft.local_sgd import DiLoCo, LocalSGD
@@ -76,6 +77,7 @@ class Runner:
world_size: int = 1
attempts: int = 3
manager_args: Dict[str, object] = field(default_factory=dict)
train_loop_args: Dict[str, Any] = field(default_factory=dict)

def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
store = dist.TCPStore(
@@ -103,7 +105,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
try:
fut.result()
except Exception as e:
logger.exception(f"worker threw exception: {e}")
logger.exception(f"worker {self.replica_id=} threw exception: {e}")
raise

return [fut.result() for fut in futures]
@@ -257,27 +259,31 @@ def diloco_train_loop(
runner: Runner,
) -> Dict[str, Dict[str, object]]:
with ExitStack() as stack:
torch.manual_seed(42)

# Declare the model and optimizers
m: nn.Module = MyModel()
model_state_dict: Dict[str, Any] = runner.train_loop_args["model_state_dict"]
m.load_state_dict(model_state_dict)

# Setup optimizers
inner_optimizer: optim.Optimizer = torch.optim.AdamW(
m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
outer_optimizer: optim.Optimizer = torch.optim.SGD(
m.parameters(), lr=0.7, momentum=0.9, nesterov=True
)

# pyre-ignore[53]
def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
m.load_state_dict(state_dict["model"])
# TODO: make this cleaner so we don't have to save this
diloco._backup_parameters = state_dict["backup_params"]
inner_optimizer.load_state_dict(state_dict["inner_optim"])
outer_optimizer.load_state_dict(state_dict["outer_optim"])

def state_dict() -> Dict[str, Dict[str, object]]:
def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
return {
"model": m.state_dict(),
"backup_params": copy.deepcopy(diloco._backup_parameters),
"inner_optim": inner_optimizer.state_dict(),
"outer_optim": outer_optimizer.state_dict(),
}
@@ -303,14 +309,8 @@ def state_dict() -> Dict[str, Dict[str, object]]:
)
stack.callback(manager.shutdown)

# TODO: where in the training loop should we do this?
# Ensure all models have the same starting state
# We set manual seed so the models start with the same weights
manager.start_quorum()
for param in m.parameters():
manager.allreduce(param.data)

criterion = nn.CrossEntropyLoss()
all_state_dicts = {}
with DiLoCo(
manager, m, inner_optimizer, outer_optimizer, sync_every=2
) as diloco:
@@ -324,16 +324,17 @@ def state_dict() -> Dict[str, Dict[str, object]]:
inner_optimizer.zero_grad()
loss.backward()
inner_optimizer.step()
manager_step_str = str(manager.current_step())
all_state_dicts[manager_step_str] = state_dict()

# after 4 model updates then break
if manager.current_step() >= 4:
break

runner.failure_injector.check(rank, manager.current_step())

return_state_dict = state_dict()
# return state_dict so we can check consistency
return return_state_dict
return all_state_dicts


class ManagerIntegTest(TestCase):
@@ -524,6 +525,11 @@ def test_diloco_healthy(self) -> None:
num_replicas = 2
futures = []

torch.manual_seed(42)
# Initialize the model so we can pass in the state_dict
m: nn.Module = MyModel()
print(m.state_dict())

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id in range(num_replicas):
failure_injector = FailureInjector()
@@ -532,6 +538,9 @@ def test_diloco_healthy(self) -> None:
lighthouse_address=lighthouse.address(),
failure_injector=failure_injector,
train_loop=diloco_train_loop,
train_loop_args={
"model_state_dict": m.state_dict(),
},
)
futures.append(executor.submit(runner.run_replica))

@@ -542,12 +551,16 @@

lighthouse.shutdown()

for state_dict in state_dicts:
# inner optimizer will be different, outer optimizer and model should be the same
torch.testing.assert_close(state_dict["model"], state_dicts[0]["model"])
torch.testing.assert_close(
state_dict["outer_optim"], state_dicts[0]["outer_optim"]
)
for replica_group in state_dicts:
for step, state_dict in replica_group.items():
# inner optimizer will be different, outer optimizer and model should be the same
torch.testing.assert_close(
state_dict["backup_params"],
state_dicts[0][str(step)]["backup_params"],
)
torch.testing.assert_close(
state_dict["outer_optim"], state_dicts[0][str(step)]["outer_optim"]
)

def test_diloco_recovery(self) -> None:
lighthouse = Lighthouse(
@@ -562,6 +575,10 @@ def test_diloco_recovery(self) -> None:
FailureInjector().fail_at(0, 2),
]

torch.manual_seed(42)
# Initialize the model so we can pass in the state_dict
m: nn.Module = MyModel()

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id, failure_injector in zip(
range(num_replicas), failure_injectors
@@ -571,6 +588,9 @@
lighthouse_address=lighthouse.address(),
failure_injector=failure_injector,
train_loop=diloco_train_loop,
train_loop_args={
"model_state_dict": m.state_dict(),
},
)
futures.append(executor.submit(runner.run_replica))

@@ -584,12 +604,19 @@
raise

lighthouse.shutdown()
# for state_dict in state_dicts:
# # inner optimizer will be different, outer optimizer and model should be the same
# torch.testing.assert_close(state_dict["model"], state_dicts[0]["model"])
# torch.testing.assert_close(
# state_dict["outer_optim"], state_dicts[0]["outer_optim"]
# )
for replica_group in state_dicts:
for step, state_dict in replica_group.items():
str_step = str(step)
if str_step in state_dicts[0]:
# inner optimizer will be different, outer optimizer and model should be the same
torch.testing.assert_close(
state_dict["backup_params"],
state_dicts[0][str_step]["backup_params"],
)
torch.testing.assert_close(
state_dict["outer_optim"],
state_dicts[0][str_step]["outer_optim"],
)

self.assertEqual(failure_injectors[1].count, 1)

