Update on "[WIP] Add DiLoCo"
Still WIP but open to feedback on the API

## API Usage

```python
from unittest.mock import create_autospec

import torch
from torch import optim

# LocalSGD, DiLoCo, and Manager come from this library; SimpleModel and
# dataloader are assumed to be defined elsewhere.

# LocalSGD example
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)  # mocked Manager, for illustration only
with LocalSGD(manager, model, optimizer, sync_every=2):
    for inp, label in dataloader:
        optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        optimizer.step()

# DiLoCo example
model = SimpleModel()
inner_optimizer = torch.optim.AdamW(
    model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)  # mocked Manager, for illustration only
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2):
    for inp, label in dataloader:
        inner_optimizer.zero_grad()
        loss = model(inp).mean()
        loss.backward()
        inner_optimizer.step()
        # outer_optimizer is applied internally every `sync_every` steps;
        # the user never calls it directly.
```
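To make the hidden `outer_optimizer` step more concrete, here is a minimal, torch-free sketch of the arithmetic a DiLoCo-style outer step performs, following the general DiLoCo recipe rather than this PR's actual code. The function name `outer_step`, the plain-SGD outer update (no momentum or Nesterov), and the in-process list of replicas standing in for an allreduce are all assumptions made for illustration.

```python
from typing import List


def outer_step(
    global_params: List[float],
    replica_params: List[List[float]],
    outer_lr: float,
) -> List[float]:
    """Hypothetical outer update: plain SGD on the averaged pseudo-gradient."""
    n = len(replica_params)
    new_params = []
    for j, g in enumerate(global_params):
        # Pseudo-gradient: how far each replica moved away from the last
        # synchronized value, averaged across replicas (stands in for allreduce).
        pseudo_grad = sum(g - rp[j] for rp in replica_params) / n
        new_params.append(g - outer_lr * pseudo_grad)
    return new_params


# Parameters at the last sync, and two replicas after their inner steps.
global_params = [1.0, -2.0]
replica_params = [[0.8, -1.5], [0.6, -1.9]]

# With outer_lr=1.0 the update reduces to averaging the replicas,
# which is the LocalSGD behaviour.
averaged = outer_step(global_params, replica_params, outer_lr=1.0)

# A smaller outer_lr moves only part of the way toward that average.
damped = outer_step(global_params, replica_params, outer_lr=0.5)
```

Seen this way, LocalSGD is the special case of an outer SGD step with learning rate 1 and no momentum, which is consistent with the two classes sharing most of their code.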

## Changes
- Updated `LocalSGD` to be a context manager rather than an `nn.Module` wrapper. This required adding a `pre_forward_hook` to the model to start the quorum.
- Added `DiLoCo`. This is a subclass of `LocalSGD`, since a lot of the code is shared.
- TODO: the implementation should be working, but some tests are still being validated.
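The context-manager pattern in the first bullet can be sketched with standard PyTorch hooks. This is only an illustration, not the PR's implementation: the class `SyncEvery`, its `events` list, and the choice to record `"start_quorum"` only on the first forward pass of each sync window are assumptions. The real `LocalSGD` talks to a `Manager`, which this sketch replaces with a log of events.

```python
import torch
from torch import nn, optim


class SyncEvery:
    """Hypothetical stand-in showing hooks installed by a context manager."""

    def __init__(self, model, optimizer, sync_every):
        self.model = model
        self.optimizer = optimizer
        self.sync_every = sync_every
        self.local_step = 0
        self.events = []  # records where quorum/sync calls would happen
        self._handles = []

    def __enter__(self):
        # Runs before every model(...) call.
        self._handles.append(
            self.model.register_forward_pre_hook(self._pre_forward)
        )
        # Runs after every optimizer.step() call.
        self._handles.append(
            self.optimizer.register_step_post_hook(self._post_step)
        )
        return self

    def __exit__(self, exc_type, exc, tb):
        # Remove the hooks so the model and optimizer are left untouched.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        return False

    def _pre_forward(self, module, args):
        if self.local_step == 0:
            self.events.append("start_quorum")

    def _post_step(self, optimizer, args, kwargs):
        self.local_step += 1
        if self.local_step >= self.sync_every:
            self.events.append("sync")
            self.local_step = 0


model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
with SyncEvery(model, optimizer, sync_every=2) as ctx:
    for _ in range(4):
        optimizer.zero_grad()
        model(torch.rand(2, 3)).mean().backward()
        optimizer.step()
events = list(ctx.events)
```

Because the hooks are attached on entry and removed on exit, the training loop keeps calling `model(...)` and `optimizer.step()` exactly as it would without the wrapper, which appears to be the motivation for moving away from an `nn.Module` wrapper.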

Discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk

[ghstack-poisoned]
H-Huang committed Jan 28, 2025
2 parents 481d2ce + 22a474c commit 9f8f576
Showing 3 changed files with 23 additions and 670 deletions.