Update base for Update on "[WIP] Add DiLoCo"
Still WIP but open to feedback on the API

## API Usage

```python
from unittest.mock import create_autospec

import torch
from torch import optim

# Manager, LocalSGD, and DiLoCo are provided by this repository;
# SimpleModel and dataloader are placeholders for user code.

# LocalSGD example
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)
with LocalSGD(manager, model, optimizer, sync_every=2):
    for inp, label in dataloader:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss = model(inp).mean()
        loss.backward()
        optimizer.step()

# DiLoCo example
model = SimpleModel()
inner_optimizer = torch.optim.AdamW(
    model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
)
outer_optimizer = torch.optim.SGD(
    model.parameters(), lr=0.7, momentum=0.9, nesterov=True
)
manager = create_autospec(Manager)
with DiLoCo(manager, model, inner_optimizer, outer_optimizer, sync_every=2):
    for inp, label in dataloader:
        inner_optimizer.zero_grad()  # clear gradients left over from the previous step
        loss = model(inp).mean()
        loss.backward()
        inner_optimizer.step()
        # outer_optimizer runs once every `sync_every` steps; DiLoCo hides this from the user
```
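To make the hidden outer step more concrete, here is a minimal, framework-free sketch of what a DiLoCo-style outer update computes, based on the published DiLoCo algorithm rather than on the code in this commit. The function names (`average`, `diloco_outer_step`) and the use of plain Python lists are assumptions for illustration only; the update rule mirrors the documented Nesterov form of `torch.optim.SGD`.

```python
def average(vectors):
    """Element-wise mean of equally sized lists (stands in for an allreduce)."""
    n = len(vectors)
    return [sum(vals) / n for vals in zip(*vectors)]


def diloco_outer_step(global_params, replica_params, momentum_buf, lr=0.7, momentum=0.9):
    """One hypothetical outer step.

    global_params:  parameters every replica started from at the last sync
    replica_params: one parameter list per replica, after `sync_every` inner steps
    momentum_buf:   outer optimizer momentum state (zeros before the first sync)
    """
    # Pseudo-gradient: how far each replica moved away from the shared start point.
    pseudo_grads = [
        [g - p for g, p in zip(global_params, local)] for local in replica_params
    ]
    avg_grad = average(pseudo_grads)

    # SGD with Nesterov momentum, treating the averaged pseudo-gradient as the gradient.
    new_buf = [momentum * b + g for b, g in zip(momentum_buf, avg_grad)]
    new_global = [
        p - lr * (g + momentum * b)
        for p, g, b in zip(global_params, avg_grad, new_buf)
    ]
    return new_global, new_buf


# Two replicas that drifted from the same starting point.
start = [1.0, 2.0]
replicas = [[0.5, 2.0], [0.75, 1.0]]
new_params, buf = diloco_outer_step(start, replicas, [0.0, 0.0])
```

With `lr=1.0` and `momentum=0.0` the outer step reduces to plain parameter averaging across replicas, which is the LocalSGD behaviour; this may be one reason the two classes can share so much code.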

## Changes
- Updated `LocalSGD` to be a context manager rather than an `nn.Module` wrapper. This required adding a forward pre-hook to the model to start the quorum.
- Added `DiLoCo`, implemented as a subclass of `LocalSGD` since much of the code is shared.
- TODO: the implementation should be working, but some tests are still being validated.
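The context-manager design can be sketched as follows. This is a hypothetical stand-in, not the implementation in this commit: the class name `SyncEvery` and its `events` list are invented for illustration, and counting steps through an optimizer post-step hook is an assumption. It relies only on the standard PyTorch calls `nn.Module.register_forward_pre_hook` and `Optimizer.register_step_post_hook` (the latter requires PyTorch 2.0 or newer).

```python
import torch
from torch import nn


class SyncEvery:
    """Hypothetical stand-in for a LocalSGD-style context manager."""

    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer, sync_every: int):
        self._model = model
        self._optimizer = optimizer
        self._sync_every = sync_every
        self._local_step = 0
        self._handles = []
        self.events = []  # records what happened, for illustration only

    def __enter__(self):
        # Called before every model.forward(); a real implementation could start the quorum here.
        self._handles.append(self._model.register_forward_pre_hook(self._pre_forward))
        # Called after every optimizer.step(); used here to count local steps.
        self._handles.append(self._optimizer.register_step_post_hook(self._post_step))
        return self

    def __exit__(self, exc_type, exc, tb):
        # Remove both hooks so the model and optimizer behave normally afterwards.
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        return False

    def _pre_forward(self, module, args):
        # Only the first forward pass of each sync window starts a quorum.
        if self._local_step == 0:
            self.events.append("start_quorum")

    def _post_step(self, optimizer, args, kwargs):
        self._local_step += 1
        if self._local_step >= self._sync_every:
            self.events.append("sync")  # a real implementation would average parameters here
            self._local_step = 0


model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
with SyncEvery(model, optimizer, sync_every=2) as ctx:
    for _ in range(4):
        optimizer.zero_grad()
        model(torch.randn(8, 4)).mean().backward()
        optimizer.step()
```

Because the hooks are installed on entry and removed on exit, the training loop itself stays unchanged, which is the main ergonomic difference from wrapping the model in another `nn.Module`.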

discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk

[ghstack-poisoned]
H-Huang committed Jan 28, 2025
1 parent 68d4059 commit 22a474c
