Honor model dtype in load_checkpoint #920

Merged 2 commits into main from honor_model_dtype on Dec 20, 2022

Conversation

sgugger (Collaborator) commented on Dec 13, 2022

This PR fixes a long-standing bug where Accelerate behaves differently from PyTorch. In PyTorch, loading a state_dict into a model never changes the model's dtype:

import torch

model = torch.nn.Linear(5, 6)
model_in_half = model.half()  # .half() converts in place and returns the model
state_dict_in_half = model_in_half.state_dict()  # float16 state dict

model = torch.nn.Linear(5, 6).to("meta")  # fresh model, parameters are float32
model.load_state_dict(state_dict_in_half)
model.weight.dtype
# returns torch.float32: the float16 state dict did not change the model's dtype

Currently, Accelerate's load_checkpoint_in_model does the opposite: when loading a checkpoint, it converts the model to the dtype of the state dict. This PR addresses that (see the sketch below).

This PR only contains the fix for now; we still need to discuss whether to maintain backward compatibility in some form (even though this is a bug fix), because diffusers might be relying on the old behavior, cc @patrickvonplaten.
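
For illustration, here is a minimal sketch (not part of the PR itself) of the behavior the fix targets. It assumes accelerate.utils.load_checkpoint_in_model, a writable working directory, and a device_map so that the set_module_tensor_to_device code path is exercised:

import torch
from accelerate.utils import load_checkpoint_in_model

# Save a float16 checkpoint to disk (hypothetical file name).
torch.save(torch.nn.Linear(5, 6).half().state_dict(), "checkpoint.bin")

# Load it into a float32 model without passing an explicit dtype.
model = torch.nn.Linear(5, 6)
load_checkpoint_in_model(model, "checkpoint.bin", device_map={"": "cpu"})

print(model.weight.dtype)
# before this PR: torch.float16 (the model is converted to the checkpoint's dtype)
# after this PR:  torch.float32 (the model keeps its own dtype, like load_state_dict)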

@sgugger sgugger requested a review from muellerzr December 13, 2022 20:03
HuggingFaceDocBuilderDev commented on Dec 13, 2022

The documentation is not available anymore as the PR was closed or merged.

muellerzr (Collaborator) left a comment


Thanks, that makes sense. I don't particularly think there is any "harm" in silently pushing it out (i.e. don't advertise the bad behavior, but let it still pass) in this particular case. If we do care about phasing that out, perhaps leave it for 1.0.0? (Similar to some optimizer bits we have.)

patrickvonplaten (Contributor) commented on Dec 16, 2022

Actually, before merging, could it maybe be better to handle this in set_module_tensor_to_device? E.g. add a dtype argument to the function there? This would make it easier for diffusers to be in line with accelerate, I think - see: https://github.com/huggingface/diffusers/blob/727434c206f6c22b746e460293035a1324f0bc13/src/diffusers/modeling_utils.py#L491

Inline comment on this diff excerpt:

        break

if old_param is not None:
    param = param.to(old_param.dtype)
patrickvonplaten (Contributor) commented on Dec 16, 2022

Would this not be better done in set_module_tensor_to_device? Or maybe additionally add a torch_dtype arg to set_module_tensor_to_device that handles the param correctly when value=param is used?

Inline comment on this diff excerpt:

@@ -680,8 +694,7 @@ def load_checkpoint_in_model(
    else:
        for param_name, param in checkpoint.items():
            module_name = param_name
            if dtype is not None and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
sgugger (Collaborator, Author) replied:

This is moved to set_module_tensor_to_device.
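
For reference, a rough sketch of how the two code paths behave, assuming the dtype argument that this PR adds to accelerate.utils.set_module_tensor_to_device:

import torch
from accelerate.utils import set_module_tensor_to_device

model = torch.nn.Linear(5, 6)
new_weight = torch.randn(6, 5, dtype=torch.float16)

# With no dtype argument, the value is cast to the dtype the model already uses,
# mirroring torch.nn.Module.load_state_dict.
set_module_tensor_to_device(model, "weight", "cpu", value=new_weight)
print(model.weight.dtype)  # torch.float32

# Passing dtype explicitly overrides the model's dtype, e.g. to keep the checkpoint's precision.
set_module_tensor_to_device(model, "weight", "cpu", value=new_weight, dtype=torch.float16)
print(model.weight.dtype)  # torch.float16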

patrickvonplaten (Contributor) left a comment

Thanks a lot for adapting!

@sgugger sgugger merged commit aa53327 into main Dec 20, 2022
@sgugger sgugger deleted the honor_model_dtype branch December 20, 2022 07:48
borzunov added a commit to bigscience-workshop/petals that referenced this pull request on Apr 25, 2023:
- After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, that function stores the tensor in the dtype previously used by the model instead of the dtype of the weights (huggingface/accelerate#920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`); see the sketch below
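
As context, a hedged sketch (not the actual petals code) of what resolving torch_dtype="auto" from a model's config.json typically looks like with transformers:

import torch
from transformers import AutoConfig

def resolve_torch_dtype(model_name_or_path, torch_dtype="auto"):
    # "auto" means: take torch_dtype from the model's config.json, falling back to float32.
    if torch_dtype == "auto":
        config = AutoConfig.from_pretrained(model_name_or_path)
        torch_dtype = getattr(config, "torch_dtype", None) or torch.float32
    # config.json may store the dtype as a string such as "float16".
    if isinstance(torch_dtype, str):
        torch_dtype = getattr(torch, torch_dtype)
    return torch_dtype

# Example (hypothetical model id):
# dtype = resolve_torch_dtype("some-org/some-model")  # e.g. torch.bfloat16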