
[T5] fix fp16 loading issue #20878

Merged (4 commits) on Dec 26, 2022

Conversation

younesbelkada (Contributor) commented Dec 22, 2022

What does this PR do?

This PR mainly fixes https://github.com/huggingface/transformers/actions/runs/3754402958/jobs/6378652143

Since huggingface/accelerate#920 was merged, the fix proposed in #20760 no longer works on the main branch of accelerate in some specific cases.

To reproduce (use the main branch of accelerate):

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
# #20760 intentionally keeps the `wo` weights in fp32, so this should print torch.float32
print(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype)
>>> torch.float16

Why?

I believe this is because the aforementioned PR introduced a new `dtype` argument on the function `set_module_tensor_to_device`: if this argument is left as `None` (the default), the target value is automatically cast to the dtype of the old tensor, which breaks some assumptions made in #20760.
I believe upstreaming this change to `modeling_utils` by adding support for this new argument should be the fix. Since some users might not run the latest version of accelerate, I added a small hack to make this change backward compatible, but I am not sure it is the best solution.
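To illustrate the cast described above, here is a minimal standalone sketch, runnable against the main branch of accelerate (the `nn.Linear` is an illustrative stand-in for the T5 `wo` layer, not code from this PR):

import torch
import torch.nn as nn
from accelerate.utils import set_module_tensor_to_device

# An fp16 module, mirroring a checkpoint loaded with torch_dtype=torch.float16
lin = nn.Linear(2, 2).half()
fp32_value = torch.zeros(2, 2, dtype=torch.float32)

# With dtype=None (the default), the new value is cast to the dtype of the
# tensor it replaces, so the intended fp32 upcast is silently undone
set_module_tensor_to_device(lin, "weight", "cpu", value=fp32_value)
print(lin.weight.dtype)  # torch.float16

# Passing dtype explicitly keeps the weight in fp32
set_module_tensor_to_device(lin, "weight", "cpu", value=fp32_value, dtype=torch.float32)
print(lin.weight.dtype)  # torch.float32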

Tested this fix on both the main branch of accelerate and accelerate==0.15.0; all relevant tests pass.

cc @sgugger

HuggingFaceDocBuilderDev commented Dec 22, 2022

The documentation is not available anymore as the PR was closed or merged.

younesbelkada added the `Core: Modeling` (Internals of the library; Models.) label on Dec 22, 2022
sgugger (Collaborator) left a comment

Thanks for fixing! LGTM with just one nit.

In src/transformers/modeling_utils.py:

force_upcast_dtype = torch.float32

# For backward compatibility with older versions of `accelerate`
# (`set_module_tensor_to_device` only takes a fifth argument, `dtype`, on newer versions)
if set_module_tensor_to_device.__code__.co_argcount == 5:
sgugger (Collaborator):

Slight nit: can we check the signature and parameter names using `inspect`? It would be clearer to read. Also add a TODO that this should become a version check at the next version of Accelerate (I will take care of it after the next release).

younesbelkada (Contributor, Author):

Thanks! Should be addressed in 95486c3:

- remove `force_upcast_dtype` since it is only used once
- use `inspect`
- add a `TODO`
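For reference, a sketch of the resulting pattern (the module and value below are illustrative stand-ins; the actual change lives in modeling_utils.py):

import inspect
import torch
import torch.nn as nn
from accelerate.utils import set_module_tensor_to_device

module, tensor_name = nn.Linear(2, 2).half(), "weight"  # illustrative stand-ins
value = torch.zeros(2, 2, dtype=torch.float32)

# Older accelerate releases have no `dtype` parameter, so check the signature
# with `inspect` instead of hard-coding an argument count
accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters

# TODO: replace with an accelerate version check after the next release
if accepts_dtype:
    set_module_tensor_to_device(module, tensor_name, "cpu", value=value, dtype=torch.float32)
else:
    set_module_tensor_to_device(module, tensor_name, "cpu", value=value)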
younesbelkada merged commit accad48 into huggingface:main on Dec 26, 2022
MKhalusova pushed a commit to MKhalusova/transformers that referenced this pull request Dec 28, 2022

* fix fp16 loading issue
* add backward compatibility
* better refactor
* better readability
  - remove `force_upcast_dtype` as it is used once
  - use `inspect`
  - add `TODO`
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jan 4, 2023
silverriver pushed a commit to silverriver/transformers that referenced this pull request Jan 6, 2023
venkat-natchi pushed a commit to venkat-natchi/transformers that referenced this pull request Jan 22, 2023
miyu386 pushed a commit to miyu386/transformers that referenced this pull request Feb 9, 2023