diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py
index 9912660b6..abcab412c 100644
--- a/fairscale/optim/oss.py
+++ b/fairscale/optim/oss.py
@@ -326,7 +326,10 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
 
         self._all_states = []
         should_collect_state = self.rank == recipient_rank or recipient_rank == -1
-        should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1
+        should_send_state = self.rank != recipient_rank
+
+        # NCCL requires CUDA tensors for all communication primitives
+        dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device
 
         for rank in range(self.world_size):
             if rank == self.rank:
@@ -340,18 +343,18 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
                 state_to_share = (
                     self.optim.state_dict()
                     if should_send_state
-                    else torch.tensor([0], dtype=torch.uint8, device=self._default_device)
+                    else torch.tensor([0], dtype=torch.uint8, device=dist_device)
                 )
                 broadcast_object(
-                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=self._default_device,
+                    state_to_share, src_rank=self.global_rank, group=self.group, dist_device=dist_device,
                 )
             else:
                 # Fetch the optim state from the other replicas
                 replica_state = broadcast_object(
-                    torch.tensor([0], dtype=torch.uint8, device=self._default_device),
+                    torch.tensor([0], dtype=torch.uint8, device=dist_device),
                     src_rank=self._local_to_global_rank[rank],
                     group=self.group,
-                    dist_device=self._default_device,
+                    dist_device=dist_device,
                 )
 
                 if should_collect_state:
diff --git a/tests/optim/test_oss.py b/tests/optim/test_oss.py
index 87ff6056f..55f356470 100644
--- a/tests/optim/test_oss.py
+++ b/tests/optim/test_oss.py
@@ -470,6 +470,11 @@ def closure():
 
     _ = optimizer.step(closure=closure)
     check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)
 
+    # Check that if the model is moved to cpu, the optimizer consolidation still works
+    model.cpu()
+    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)
+    optimizer.consolidate_state_dict(recipient_rank=reference_rank)
+
     dist.destroy_process_group()
 
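
The core of the fix is choosing the device for the collective based on the process group backend rather than on the model's device: NCCL only operates on CUDA tensors, so a CPU-resident model would otherwise make `broadcast_object` fail during consolidation. A minimal sketch of that pattern outside of fairscale is shown below; the helper name `collective_device` is illustrative and not part of the library's API.

```python
# Sketch of the backend-aware device selection applied in this patch.
# `collective_device` is a hypothetical helper, not a fairscale function.
import torch
import torch.distributed as dist


def collective_device(default_device: torch.device) -> torch.device:
    """Pick the device to stage tensors on before calling a collective.

    NCCL only accepts CUDA tensors, so even when the model (and therefore
    `default_device`) lives on the CPU, payloads must be moved to a CUDA
    device first. Gloo and MPI accept CPU tensors, so the model's own
    device can be used unchanged.
    """
    # Requires an initialized process group (torch.distributed.init_process_group).
    if dist.get_backend() == dist.Backend.NCCL:
        return torch.device("cuda")
    return default_device
```

The patch itself reads the backend from the attribute cached on the optimizer (`self.backend`) instead of querying the default process group on every call, which amounts to the same decision made once per `consolidate_state_dict` invocation.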