[fix] OSS dict load/save fix - better fix than 383 and unit test #386
Conversation
@@ -859,10 +861,16 @@ def closure_sharded(input_tensor=input_tensor):
sharded_optim_state_dict = sync_object_ranks(sharded_optim_state_dict, RECIPIENT_RANK, device)

# - cross load the states
# run one step and check that the models are still the same
this modified unit test does break on the old version
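For readers without the full test file, a hedged, generic sketch (plain torch with hypothetical model/optimizer names, not the fairscale test itself) of the round trip this test exercises: save an optimizer's state, load it into a fresh optimizer, run one step with each, and check the results still match.

```python
import copy
import torch

model_a = torch.nn.Linear(4, 4)
model_b = copy.deepcopy(model_a)

opt_a = torch.optim.SGD(model_a.parameters(), lr=0.1, momentum=0.9)
opt_b = torch.optim.SGD(model_b.parameters(), lr=0.1, momentum=0.9)

x = torch.randn(2, 4)

def step(model, opt):
    # one optimization step on a shared input, returning the loss
    opt.zero_grad()
    loss = model(x).sum()
    loss.backward()
    opt.step()
    return loss

step(model_a, opt_a)                          # build up some optimizer state (momentum)
opt_b.load_state_dict(opt_a.state_dict())     # cross load the states
model_b.load_state_dict(model_a.state_dict())

# run one step and check that the models are still the same
loss_a = step(model_a, opt_a)
loss_b = step(model_b, opt_b)
assert torch.allclose(loss_a, loss_b)
```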
@@ -379,6 +379,8 @@ def state_dict(self) -> Dict[str, Any]:
global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
state_dict["state"][global_id] = s["state"][local_param_index]

+ # Make sure that the parameters are sorted in the state, as expected
+ state_dict["state"] = dict(sorted(state_dict["state"].items()))
The returned state dict was properly sorted under the "param_groups" key, but not under the "state" field, which followed the partitioning order. The loading code assumed it was sorted, so that would break.
PyTorch just uses the ordering from the "param_groups" key, and I was only testing loading OSS -> PyTorch and vice versa, so this was unfortunately not caught.
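A minimal, self-contained sketch (with made-up keys and values standing in for the real per-param optimizer state) of why an unsorted "state" field breaks a loader that assumes "param_groups" ordering:

```python
# made-up keys/values standing in for {global param index: per-param optimizer state}
state = {2: "exp_avg for p2", 0: "exp_avg for p0", 1: "exp_avg for p1"}  # partition order
params_in_group_order = ["p0", "p1", "p2"]                               # "param_groups" order

# a loader that assumes the "state" keys follow the "param_groups" ordering
# pairs the wrong states with the wrong parameters
broken = list(zip(params_in_group_order, state.values()))
# [('p0', 'exp_avg for p2'), ('p1', 'exp_avg for p0'), ('p2', 'exp_avg for p1')]

# the fix: re-sort "state" by its global index before returning the state dict
state = dict(sorted(state.items()))
fixed = list(zip(params_in_group_order, state.values()))
# [('p0', 'exp_avg for p0'), ('p1', 'exp_avg for p1'), ('p2', 'exp_avg for p2')]
```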
I didn't know Python's dictionaries preserve insertion order, so I just looked it up. Turns out this has been an implementation detail since CPython 3.6 and a language guarantee since 3.7, good to know! https://stackoverflow.com/questions/39980323/are-dictionaries-ordered-in-python-3-6
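A tiny demonstration of the property being relied on: plain dicts keep keys in insertion order, so rebuilding a dict from sorted items yields deterministically ordered keys with no OrderedDict needed.

```python
# plain dicts preserve insertion order (CPython 3.6 implementation detail, guaranteed from 3.7)
unsorted = {2: "state for param 2", 0: "state for param 0"}
ordered = dict(sorted(unsorted.items()))
assert list(ordered.keys()) == [0, 2]
```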
# Only add this state to the sharded optimizer if it owns this param
for pg in self.optim.param_groups:
    if id(param) in [id(p) for p in pg["params"]]:
This second check could mask an issue: we already checked above that this rank owns this param, so it is not needed (and potentially risky).
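A hedged illustration of the point (names and structure are made up, not the fairscale internals): once the rank-ownership check has passed, re-scanning every param group for the same parameter only duplicates the check and can quietly paper over a partitioning bug rather than surface it.

```python
def restore_param_state(param_to_rank, rank, param_groups, optim_state, param, state):
    # first (sufficient) check: does this rank own the parameter?
    if param_to_rank[param] != rank:
        return

    # the redundant second ownership check looked roughly like this and was dropped:
    # for pg in param_groups:
    #     if id(param) in [id(p) for p in pg["params"]]:
    #         optim_state[param] = state

    # ownership is already guaranteed by the check above
    optim_state[param] = state
```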
@@ -832,8 +834,8 @@ def closure_sharded(input_tensor=input_tensor):
loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

assert torch.allclose(
-     loss_ddp, loss_sharded_optim
+     loss_ddp, loss_sharded_optim, rtol=1e-3
), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
The rtol change is unfortunately only needed on PyTorch 1.5. Without it, on a two-GPU machine the losses become 26.404895782470703 vs 26.404342651367188 (which I assume is due to a different cast and not structurally wrong) and the assert fires.
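For reference, a quick standalone check (plain torch, not part of the test) of why the default tolerance fails on the two losses quoted above: torch.allclose passes when |a - b| <= atol + rtol * |b|, with defaults rtol=1e-5 and atol=1e-8.

```python
import torch

loss_a = torch.tensor(26.404895782470703)
loss_b = torch.tensor(26.404342651367188)

# |a - b| ~ 5.5e-4, while atol + rtol * |b| ~ 2.6e-4 with the defaults -> fails
print(torch.allclose(loss_a, loss_b))             # False
# with rtol=1e-3 the bound becomes ~2.6e-2 -> passes
print(torch.allclose(loss_a, loss_b, rtol=1e-3))  # True
```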
I think it is worth documenting the reason for 1e-3.
sorry @min-xu-ai for the revert of the previous one, I just thought this was cleaner and there was one fix left out in the cold
yeah, this seems to be much nicer.
Before submitting
What does this PR do?
Fixes #380. A better take than #383, because it fixes another issue which was not caught (#383 was not enough) and reproduces the issue in an updated unit test so that this does not happen again. Thanks again @zhengwy888
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃