[fix][minor] Change empty shard handling for OSS, do not rely on asserts #460
Conversation
```diff
@@ -597,4 +593,4 @@ def _setup_flat_buffers(self) -> None:
             else:
                 self.buckets[device][dst_rank] = bucket
         else:
-            self.buckets[device].append(torch.zeros(1, device=device))
+            self.buckets[device].append(torch.zeros(0, device=device))
```
an empty tensor is a thing; it's a better way to catch empty shards
maybe add this same line as a comment for better code readability?
hmm ok, the comment was about the change actually, but if that helps I can always write something. The `if` clause is effectively `if (no params)`, so I thought that was kind of clear (outside of this very narrow PR view)
```diff
@@ -552,8 +545,11 @@ def _broadcast_params(self) -> None:
         for device in self.buckets.keys():
             for src_rank, bucket in enumerate(self.buckets[device]):
                 global_src_rank = self.get_global_rank(self.group, src_rank)
-                last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)
+                if bucket.numel() > 0:
```
Not sure about broadcasting something empty for all backends, and it does not make a ton of sense anyway, so just skip it.
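The comment above is the core of the fix: with `torch.zeros(0)` as the placeholder, an empty shard is genuinely empty and can be detected with `numel()`, whereas the old `torch.zeros(1)` placeholder was indistinguishable from a real one-element shard. A minimal sketch (hypothetical bucket list, not the fairscale code):

```python
import torch

# Old placeholder: looks like real data, numel() == 1.
placeholder_old = torch.zeros(1)
# New placeholder: genuinely empty, numel() == 0.
placeholder_new = torch.zeros(0)

print(placeholder_old.numel())  # 1 -> cannot be told apart from a tiny shard
print(placeholder_new.numel())  # 0 -> can be skipped before broadcast

# Mirroring the PR's check: skip empty buckets instead of broadcasting them.
buckets = [torch.randn(4), placeholder_new, torch.randn(2)]
non_empty = [b for b in buckets if b.numel() > 0]
print(len(non_empty))  # 2
```

In the actual `_broadcast_params` loop the same `bucket.numel() > 0` test guards the `dist.broadcast` call, so ranks holding an empty shard never issue a collective for it.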
```diff
@@ -140,13 +140,6 @@ def partition_parameters(self) -> List[List[dict]]:
                 param_group_rank["params"] = params
                 self._partition_parameters[rank].append(param_group_rank)
```
Changing the logic: if a rank has an empty shard, it will be taken care of bucket-wise; no more asserting, since that proved fragile on FB infra (the job can be compiled without the asserts).
Just to clarify: this happens when one of the nodes does not have any params assigned to it? That's an empty shard, right?
Very nice. Running with `python -O` is LOL. If that speeds things up, then the program is probably running too much Python code. :-)
Before submitting
What does this PR do?
If a shard is empty, do not assert; skip the broadcasting step instead. The prior issue was that if somebody ran a distributed job with `python -O`, the asserts were compiled out and the job would just hang.
Now, if some ranks are empty, they simply don't participate in the optimization step (they don't update any tensor); they can still participate in the data-parallel part, which is probably a better take.
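The `python -O` hazard mentioned above is worth seeing concretely: the `-O` flag strips `assert` statements entirely, so control flow that relies on an assert firing silently continues instead. A small stdlib-only demonstration:

```python
import subprocess
import sys

# A program whose only guard is an assert.
code = "assert False, 'empty shard'\nprint('survived')"

# Normal mode: the assert fires and the process exits non-zero.
normal = subprocess.run([sys.executable, "-c", code],
                        capture_output=True, text=True)
print(normal.returncode != 0)        # True: AssertionError raised

# Optimized mode (-O): asserts are compiled out, execution continues.
optimized = subprocess.run([sys.executable, "-O", "-c", code],
                           capture_output=True, text=True)
print(optimized.stdout.strip())      # survived
```

This is why the PR replaces the assert with an explicit `numel()` check: the check runs regardless of interpreter flags.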
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃