[T5] fix fp16 loading issue #20878
Conversation
The documentation is not available anymore as the PR was closed or merged.
8bf5c89 to 43006f0
Thanks for fixing! LGTM with just one nit.
src/transformers/modeling_utils.py
Outdated

force_upcast_dtype = torch.float32

# For backward compatibility with older versions of `accelerate`
if set_module_tensor_to_device.__code__.co_argcount == 5:
Slight nit: can we check the signature and parameter names using `inspect`? It would be clearer to read. Also add a TODO that this should become a version check at the next version of Accelerate (I will take care of it after the next release).
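For illustration, a minimal sketch of what such an `inspect`-based check could look like (the helper name `set_tensor_compat` is hypothetical and the unconditional fp32 upcast is only illustrative; this is not necessarily the exact code that was merged):

```python
# Sketch only: detect the new `dtype` parameter by name instead of counting
# positional arguments with `__code__.co_argcount`.
import inspect

import torch
from accelerate.utils import set_module_tensor_to_device


def set_tensor_compat(module, tensor_name, device, value):
    """Hypothetical helper showing the inspect-based dispatch (not the merged code)."""
    accepts_dtype = "dtype" in inspect.signature(set_module_tensor_to_device).parameters
    # TODO: replace this with a version check once the next `accelerate` release is out.
    if accepts_dtype:
        # Newer accelerate: request the upcast explicitly so fp32-only modules stay in fp32.
        set_module_tensor_to_device(module, tensor_name, device, value=value, dtype=torch.float32)
    else:
        # Older accelerate: no `dtype` argument available.
        set_module_tensor_to_device(module, tensor_name, device, value=value)
```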
Thanks! Should be addressed in 95486c3
- remove `force_upcast_dtype` as it is used once
- use `inspect`
- add `TODO`
* fix fp16 loading issue
* add backward compatibility
* better refactor
* better readability
  - remove `force_upcast_dtype` as it is used once
  - use `inspect`
  - add `TODO`
What does this PR do?
This PR mainly fixes https://github.com/huggingface/transformers/actions/runs/3754402958/jobs/6378652143
Since the PR huggingface/accelerate#920 has been merged, the fix proposed in #20760 seems to no longer work with the main branch of `accelerate` in some specific cases.
To reproduce (use the main branch of `accelerate`), see the sketch below:
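The original reproduction snippet is not visible in this view; a minimal sketch of the kind of call that exercises the issue, assuming a T5 checkpoint loaded in fp16 through the accelerate loading path (the checkpoint name and kwargs are illustrative):

```python
import torch
from transformers import T5ForConditionalGeneration

# Loading a T5 checkpoint in fp16 via the low_cpu_mem_usage (accelerate) path
# exercises the fp32-upcast logic added in #20760.
model = T5ForConditionalGeneration.from_pretrained(
    "t5-small",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# With the fix, modules listed in `_keep_in_fp32_modules` should end up in torch.float32.
```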
Why?
I believe this is because the aforementioned PR introduced a new argument `dtype` on the function `set_module_tensor_to_device`. If this argument is set to `None` (the default), the target value is automatically set to the `dtype` of the old tensor, which slightly breaks some assumptions made in #20760. A minimal sketch of that behavior follows.
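A minimal sketch of the `dtype=None` behavior described above, assuming the post-#920 signature `set_module_tensor_to_device(module, tensor_name, device, value=None, dtype=None)` (toy tensors, not actual T5 weights):

```python
import torch
from accelerate.utils import set_module_tensor_to_device

# Toy module whose weight is already fp16, as when a model is instantiated with torch_dtype=torch.float16.
layer = torch.nn.Linear(4, 4).half()

# Checkpoint value that #20760 wants to keep in fp32 for numerical stability.
fp32_value = torch.randn(4, 4, dtype=torch.float32)

# dtype=None (the default): the value is cast to the dtype of the old tensor, i.e. fp16 here.
set_module_tensor_to_device(layer, "weight", "cpu", value=fp32_value)
print(layer.weight.dtype)  # torch.float16

# Passing dtype explicitly keeps the parameter in fp32, which is what this PR relies on.
set_module_tensor_to_device(layer, "weight", "cpu", value=fp32_value, dtype=torch.float32)
print(layer.weight.dtype)  # torch.float32
```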
I believe upstreaming this change in `modeling_utils` by adding support for this new argument should be the fix. As some users might not use the latest version of `accelerate`, I added a small hack to make this change backward compatible, but I am not sure this is the best solution.
Tested this fix on the main branch of `accelerate` and on `accelerate==0.15.0`; all relevant tests pass.
cc @sgugger