Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Still WIP but open to feedback on the API ## API Usage ```python # LocalSGD example model = SimpleModel() optimizer = optim.SGD(model.parameters()) manager = create_autospec(Manager) with LocalSGD(manager, model, optimizer, sync_every=2): for inp, label in dataloader: loss = model(inp).mean() loss.backward() optimizer.step() # DiLoCo example model = SimpleModel() inner_optimizer = torch.optim.AdamW( m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95) ) outer_optimizer = torch.optim.SGD( m.parameters(), lr=0.7, momentum=0.9, nesterov=True ) manager = create_autospec(Manager) with DiLoCo(manager, m, inner_optimizer, outer_optimizer, sync_every=2): for inp, label in dataloader: loss = model(inp).mean() loss.backward() inner_optimizer.step() # outer_optimizer is actually used every 'sync_every' steps but this is hidden from the user ``` ## Changes - Updated `LocalSGD` to be a context manager rather than a `nn.Module` wrapper. This required adding a pre_forward_hook to the model start the quorum - Added DiLoCo. This is a subclass of LocalSGD since a lot of code is shared - TODO: should be working, but still validating some tests discussion doc: https://docs.google.com/document/d/11c5JwQpSzilrDvK-vNsgQhpXAihbMn-hTRC8y3LiGqY/edit?tab=t.0#heading=h.izo4yi6jz4mk [ghstack-poisoned]
- Loading branch information