[fix] OSS - enforce cuda parameters for state consolidation if NCCL backend #573
Conversation
@@ -470,6 +470,11 @@ def closure():
     _ = optimizer.step(closure=closure)
     check_same_models_across_ranks(model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=False)

+    # Check that if the model is moved to cpu, the optimizer consolidation still works
+    model.cpu()
Without the fix, this unit test fails with the same error that the user reported.
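For context, a minimal sketch of the kind of check being added here, assuming the process group, model, and OSS-wrapped optimizer from the surrounding test already exist (this is not the verbatim test code):

    # Sketch only: after the GPU training steps above, move the model to CPU and
    # check that state consolidation still succeeds under the NCCL backend.
    model.cpu()
    optimizer.consolidate_state_dict(recipient_rank=0)
    if dist.get_rank() == 0:
        _ = optimizer.state_dict()  # the consolidated state is only materialized on the recipient rank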
@@ -328,6 +328,9 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
     should_collect_state = self.rank == recipient_rank or recipient_rank == -1
     should_send_state = (self.rank != recipient_rank and recipient_rank != -1) or recipient_rank == -1

+    # NCCL requires CUDA tensors for all communication primitives
+    dist_device = torch.device("cuda") if self.backend == dist.Backend.NCCL else self._default_device
No choice with NCCL, it needs to be cuda.
If the model is moved back to the CPU and the optimizer state reflects it, why do we call broadcast? The optimizer state is not sharded anymore, right? Maybe I am missing something.
The framework is the one calling .consolidate(); it can do so at any time, basically. We could add a skip mechanism for when it's called twice in a row (that would actually be even more foolproof), but that would not solve the case of train -> move to cpu -> call .consolidate(), which can be legitimate, if unfortunate.
(complement) The issue was that if the model is moved to cpu, then some tensors in the optimizer dict are on CPU. When consolidating, the shards are exchanged towards a specific rank (or all ranks), which breaks with NCCL since it always expects cuda tensors for its communication primitives.
The broken test was unrelated (pipe).
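To make the mechanics concrete, here is a hedged sketch of the device handling around a broadcast; _broadcast_on is a made-up helper name for illustration, not the actual FairScale utility:

    import torch
    import torch.distributed as dist

    def _broadcast_on(tensor: torch.Tensor, src_rank: int, dist_device: torch.device) -> torch.Tensor:
        # NCCL collectives reject CPU tensors, so the payload is moved to dist_device
        # (a CUDA device when the backend is NCCL) before the broadcast ...
        work_tensor = tensor.to(dist_device)
        dist.broadcast(work_tensor, src=src_rank)
        # ... and moved back afterwards so the optimizer state keeps its original device.
        return work_tensor.to(tensor.device)

This assumes an already initialized process group; with the Gloo backend, dist_device can simply stay on the CPU.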
@@ -340,18 +343,18 @@ def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
     state_to_share = (
         self.optim.state_dict()
         if should_send_state
-        else torch.tensor([0], dtype=torch.uint8, device=self._default_device)
+        else torch.tensor([0], dtype=torch.uint8, device=dist_device)
This seems wasteful. Why not skip the broadcast in this case instead of sending a zero? In the else branch below you could check if rank != recipient.
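A rough sketch of the structure this comment seems to suggest, reusing the names from the diff above (self, should_send_state, recipient_rank, dist_device); this is only an illustration of the suggestion, not the merged code:

    if should_send_state:
        # This rank owns the shard being consolidated: share the real state dict.
        state_to_share = self.optim.state_dict()
    elif self.rank == recipient_rank:
        # Recipient rank: keep a small placeholder so it can take part in the exchange.
        state_to_share = torch.tensor([0], dtype=torch.uint8, device=dist_device)
    else:
        # Neither sender nor recipient: skip the broadcast instead of sending a zero.
        state_to_share = None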
Co-authored-by: msbaines <35972327+msbaines@users.noreply.github.com>
Before submitting
What does this PR do?
Should fix https://fb.workplace.com/groups/pytorchLightning/permalink/1419090048427529/
A use case that I did not think of was that the model could be moved to cpu() at some point, and then the OSS state consolidated. In that case the state consolidation would fail, because NCCL only supports cuda tensors.
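For clarity, a minimal repro sketch of the failing sequence on the previous code; it assumes an NCCL process group and one GPU per rank are already set up, and elides the launch details:

    import torch
    import torch.distributed as dist
    from fairscale.optim import OSS

    # Assumes dist.init_process_group("nccl", ...) was called and the CUDA device set per rank.
    model = torch.nn.Linear(8, 8).cuda()
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

    # ... a few training steps on GPU ...

    model.cpu()                                          # e.g. to free GPU memory before checkpointing
    optimizer.consolidate_state_dict(recipient_rank=0)   # used to fail: NCCL rejects the CPU tensors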
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃
cc @ananthsub