Skip to content

Commit

Permalink
[fix] OSS - enforce cuda parameters for state consolidation if NCCL b…
Browse files Browse the repository at this point in the history
…ackend (#573)
  • Loading branch information
blefaudeux authored Apr 4, 2021
1 parent 04001e7 commit 8855337
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
13 changes: 8 additions & 5 deletions fairscale/optim/oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,10 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:

self._all_states = []
should_collect_state = self.rank == recipient_rank or recipient_rank == -1
should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
should_send_state = self.rank != recipient_rank

# NCCL requires CUDA tensors for all communication primitives
dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device

for rank in range(self.world_size):
if rank == self.rank:
Expand All @@ -340,18 +343,18 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
state_to_share = (
self.optim.state_dict()
if should_send_state
else torch.tensor([0], dtype=torch.uint8, device=self._default_device)
else torch.tensor([0], dtype=torch.uint8, device=dist_device)
)
broadcast_object(
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=self._default_device,
state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device,
)
else:
# Fetch the optim state from the other replicas
replica_state = broadcast_object(
torch.tensor([0], dtype=torch.uint8, device=self._default_device),
torch.tensor([0], dtype=torch.uint8, device=dist_device),
src_rank=self._local_to_global_rank[rank],
group=self.group,
dist_device=self._default_device,
dist_device=dist_device,
)

if should_collect_state:
Expand Down
5 changes: 5 additions & 0 deletions tests/optim/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,11 @@ def closure():
_ = optimizer.step(closure=closure)
check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)

# Check that if the model is moved to cpu, the optimizer consolidation still works
model.cpu()
optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)
optimizer.consolidate_state_dict(recipient_rank=reference_rank)

dist.destroy_process_group()


Expand Down

0 comments on commit 8855337

Please sign in to comment.