Honor model dtype in load_checkpoint
#920
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks, that makes sense. I don't particularly think there is any "harm" in silently pushing it out (i.e. don't advertise the bad behavior, but let it still pass) in this particular case. If we do care about phasing that out, perhaps leave it for a 1.0.0? (Similar to some optimizer bits we have.)
Actually, before merging, could it maybe be better to handle this in `set_module_tensor_to_device`?
src/accelerate/utils/modeling.py
        break

    if old_param is not None:
        param = param.to(old_param.dtype)
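To make the effect of that cast concrete, here is a minimal sketch (the variable names mirror the excerpt above; the surrounding loading loop is omitted):

```python
import torch

# The model was instantiated in fp16, but the checkpoint was saved in fp32.
old_param = torch.zeros(2, 2, dtype=torch.float16)  # tensor already in the model
param = torch.ones(2, 2, dtype=torch.float32)       # tensor read from the checkpoint

# The excerpt above casts the checkpoint value back to the model's dtype,
# so loading a checkpoint no longer silently promotes the model to fp32.
if old_param is not None:
    param = param.to(old_param.dtype)

print(param.dtype)  # torch.float16
```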
Should this not be better done in `set_module_tensor_to_device`? Or maybe additionally add a `torch_dtype` arg to `set_module_tensor_to_device` that handles the param correctly if `value=param` is used?
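For illustration, a rough sketch of what that suggestion could look like. The `dtype` handling here is an assumption about the API shape, not the actual `set_module_tensor_to_device` implementation, which also deals with dotted tensor names, buffers, and meta tensors:

```python
import torch

def set_module_tensor_to_device_sketch(module, tensor_name, device, value=None, dtype=None):
    """Simplified sketch of a dtype-aware set_module_tensor_to_device."""
    old_value = getattr(module, tensor_name)
    if value is None:
        value = old_value
    elif dtype is not None:
        # An explicit dtype wins.
        value = value.to(dtype)
    elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
        # Otherwise honor the dtype the model already uses (the behavior this PR
        # introduces), never casting integer/bool tensors.
        value = value.to(old_value.dtype)
    value = value.to(device)
    if isinstance(old_value, torch.nn.Parameter):
        value = torch.nn.Parameter(value, requires_grad=old_value.requires_grad)
    setattr(module, tensor_name, value)
```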
@@ -680,8 +694,7 @@ def load_checkpoint_in_model(
    else:
        for param_name, param in checkpoint.items():
            module_name = param_name
            if dtype is not None and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
This is moved to `set_module_tensor_to_device`.
Thanks a lot for adapting!
- After #285, `load_pretrained_block()` uses `accelerate.utils.set_module_tensor_to_device()`
- In accelerate>=0.16.0, it saves the tensor in the dtype previously used by the model instead of the dtype of the weights (huggingface/accelerate#920)
- Because of that, blocks and attention caches used float32, which caused OOMs
- This PR makes `load_pretrained_block()` respect `torch_dtype` (default: `"auto"`, which means reading `torch_dtype` from `config.json`)
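As a side note on that last bullet, a minimal sketch of what `torch_dtype="auto"` resolution could look like; `resolve_torch_dtype` is a hypothetical helper, not the actual petals code:

```python
import json
import torch

def resolve_torch_dtype(torch_dtype, config_path="config.json"):
    # Hypothetical helper: "auto" means "use the torch_dtype recorded in config.json",
    # falling back to float32 if the field is missing.
    if torch_dtype != "auto":
        return torch_dtype
    with open(config_path) as f:
        config = json.load(f)
    return getattr(torch, config.get("torch_dtype", "float32"))  # e.g. "float16" -> torch.float16
```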
This PR fixes a long-standing bug where we behave differently from PyTorch. In torch, loading a `state_dict` into a model never changes the model's dtype. Currently in Accelerate, `load_checkpoint` does the opposite: when loading a model, it converts the model to the dtype of the state dict. This PR addresses that.

This PR only contains the fix for now; we still have to discuss how to maybe maintain backward compatibility (even if this is a bug fix), because `diffusers` might be relying on this behavior, cc @patrickvonplaten
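A minimal illustration (not taken from the PR itself) of the PyTorch behavior described above:

```python
import torch
from torch import nn

# PyTorch: loading a state_dict never changes the model's dtype.
model = nn.Linear(4, 4).to(torch.float16)
fp32_state_dict = {k: v.to(torch.float32) for k, v in nn.Linear(4, 4).state_dict().items()}

model.load_state_dict(fp32_state_dict)
print(model.weight.dtype)  # torch.float16 -- the fp32 checkpoint did not change the model

# Before this fix, accelerate's load_checkpoint utilities would instead leave the
# model in the state dict's dtype (float32 here), which is the bug this PR addresses.
```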