Fix dtype bug when offload_state_dict=True and dtype is specified #2116
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
LGTM! Left a few comments. I will run a few tests on bnb CPU offload to check that nothing is broken, then merge.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Gentle ping @fxmarty 🤗
598a149 to 996107b (compare)
src/accelerate/utils/modeling.py (Outdated)
@@ -1311,7 +1313,7 @@ def get_state_dict_offloaded_model(model: nn.Module):
def load_checkpoint_in_model(
    model: nn.Module,
    checkpoint: Union[str, os.PathLike],
    device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
This typing is wrong, as `None` is not an accepted value in this function (it raises an error at accelerate/src/accelerate/utils/modeling.py, line 436 in bd72a5f: `if param in device_map:`).
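To make the typing concern above concrete, here is a minimal, self-contained illustration (not accelerate code) of why `device_map=None` cannot actually reach that line: a membership test against `None` raises a `TypeError`, so despite the `Optional[...]` annotation, `None` is not a usable value.

```python
# Membership tests require an iterable (or a type implementing __contains__);
# None supports neither, so `param in device_map` fails when device_map is None.
device_map = None
try:
    "model.layer.weight" in device_map
except TypeError as exc:
    print(exc)  # argument of type 'NoneType' is not iterable
```

This is why tightening the annotation to `Dict[str, Union[int, str, torch.device]]` is the semantically correct fix, even though (as noted below in the thread) it surfaced unrelated CI failures.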
I left it as is, because changing the annotation to `device_map: Dict[str, Union[int, str, torch.device]]` raises errors in the CI.
Thanks for noticing this. This should be fixed here.
Thanks for iterating! Overall looks great to me bar one suggestion. cc @SunMarc for a final look through?
@fxmarty could you run
Done, thank you!
LGTM ! Thanks for adding a test.
As per title.

I don't have time to provide a full repro and add a test; I could do so next week. But for general context, the issue is the following: when loading a state dict with `offload_state_dict=True`, we go into accelerate/src/accelerate/utils/modeling.py, line 1403 in bd72a5f, and we are missing a check, similar to the one at accelerate/src/accelerate/utils/modeling.py, line 292 in bd72a5f, for when `dtype` is specified. The byproduct is that all parameters are cast to `dtype`, although they should not be.
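To sketch the behavior the missing guard is meant to enforce (this is an illustrative stand-in, not the actual accelerate code; `maybe_cast` is a hypothetical helper): when a target `dtype` is given, only floating-point tensors should be converted, while integer and bool tensors (e.g. quantization statistics or masks) must keep their original dtype.

```python
import torch

def maybe_cast(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper mirroring the guard that the offload path lacked:
    cast only floating-point tensors to the requested dtype."""
    if dtype is not None and tensor.is_floating_point():
        return tensor.to(dtype)
    # Non-floating tensors keep their dtype instead of being blindly cast.
    return tensor

weight = torch.randn(2, 2, dtype=torch.float32)
mask = torch.ones(2, 2, dtype=torch.bool)

assert maybe_cast(weight, torch.float16).dtype == torch.float16
assert maybe_cast(mask, torch.float16).dtype == torch.bool
```

Without such a guard on the `offload_state_dict=True` path, every offloaded tensor is converted to `dtype`, which is the bug this PR fixes.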