Add `keep_in_fp32_modules` support #20683
Conversation
What about adding hooks on each converted module, that will take care of converting the input / output to the correct dtype?
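A rough sketch of that hook idea (not the approach this PR ends up taking; the function name and the choice of `float16` as the surrounding dtype are illustrative assumptions):

```python
import torch

def add_dtype_casting_hooks(module: torch.nn.Module, outer_dtype: torch.dtype = torch.float16):
    """Cast inputs of an fp32-kept module up to float32 and cast its output
    back to the surrounding half-precision dtype."""

    def pre_hook(mod, args):
        # Upcast floating-point tensor inputs so the fp32 module computes in fp32.
        return tuple(
            a.to(torch.float32) if torch.is_tensor(a) and a.is_floating_point() else a
            for a in args
        )

    def post_hook(mod, args, output):
        # Downcast the output so the rest of the half-precision model keeps working.
        return output.to(outer_dtype) if torch.is_tensor(output) else output

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)
```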
Some initial comments:
src/transformers/modeling_utils.py
Outdated
if keep_in_fp32_modules is not None and not low_cpu_mem_usage:
    # Force `low_cpu_mem_usage` to be set to `True` - check the PR:
    logger.warning(
        "The argument `keep_in_fp32_modules` is used, force-enabling `low_cpu_mem_usage` to load the model"
    )
    low_cpu_mem_usage = True
Shouldn't be force-set here.
proposed something in 115c0d0
The documentation is not available anymore as the PR was closed or merged.
- make tests `slow`
- fix logic
src/transformers/modeling_utils.py
Outdated
logger.warning(
    " `_keep_in_fp32_modules` is not set to `None` and you don't have `accelerate` installed",
    " it is recommended to have `accelerate` installed in this case `pip install accelerate`.",
)
The warning should only be triggered when `torch_dtype == torch.float16`.
Should be fixed in 8014c34
src/transformers/modeling_utils.py
Outdated
keep_in_fp32_modules = model._keep_in_fp32_modules
if keep_in_fp32_modules is not None:
    low_cpu_mem_usage = True
Shouldn't this use `use_keep_in_fp32_modules` here? Also, it should go before so the test at line 2307 can be simplified.
I simplified the tests in 966cc06, but I think we still need `keep_in_fp32_modules = model._keep_in_fp32_modules` as it is used later on at line 2342.
I think you are right: 0f75387
Actually it seems it's trickier than that; putting it on top results in some failing tests. Should be fixed in 243e6b5.
src/transformers/modeling_utils.py
Outdated
# upcast in fp32 if any
target_dtype = dtype
if keep_in_fp32_modules is not None and any(
    module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules
Should also add a test of `dtype` being `float16` here.
added in 8014c34
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
src/transformers/modeling_utils.py
Outdated
@@ -2299,6 +2323,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        with ContextManagers(init_contexts):
            model = cls(config, *model_args, **model_kwargs)

        if use_keep_in_fp32_modules:
            low_cpu_mem_usage = True
            keep_in_fp32_modules = model._keep_in_fp32_modules
Let's set it to `[]` here if it's not `None`, so that we don't have to check again later on.
This should be addressed in cb89c42
src/transformers/modeling_utils.py
Outdated
elif keep_in_fp32_modules is not None and state_dict is not None:
    for key in state_dict:
        if any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules):
            state_dict[key] = state_dict[key].to(torch.float32)
This is not useful, as with `torch.load_state_dict` the weights are converted to the dtype inside the model. So it's the model dtype that you should fix here.
Also, this removes the necessity for the Accelerate warning above, no?
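A small self-contained demonstration of this point (illustrative, not from the PR): `load_state_dict` copies into the existing parameters, so the parameter's dtype wins over the state dict's dtype.

```python
import torch

lin = torch.nn.Linear(4, 4).half()                         # module parameters are float16
sd = {k: v.float() for k, v in lin.state_dict().items()}   # upcast a copy of the state dict
lin.load_state_dict(sd)                                     # copy_ casts back to the parameter dtype
print(lin.weight.dtype)                                     # torch.float16 -> the module dtype must be fixed instead
```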
Yes! Should be addressed in cb89c42
Thanks @younesbelkada and @sgugger!! Tested this locally; I can confirm this works with patches 1 & 2 from #20287 (comment). The only problem I encountered is that at transformers/src/transformers/modeling_utils.py, line 2326 in 0f75387, you get an error.
Thanks so much @larsmennen for confirming that the tests pass! We should be close to merging this 💪
@@ -2070,6 +2070,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        # Load model
        loading_info = None

        # Keep in fp32 modules
        keep_in_fp32_modules = None
@larsmennen the keyword has been added here, if this is what you meant.
- add `is_accelerate_available`
- fixes pipeline tests that failed
Hmm, that doesn't fix it. I think you just need to pop the argument from `model_kwargs`, otherwise it gets passed to the underlying model (I'm assuming you don't want that? but cmiiw). I.e. after transformers/src/transformers/modeling_utils.py, line 1981 in 7d47df2, if you add `keep_in_fp32_modules = kwargs.pop("keep_in_fp32_modules", None)`.
I tested with that modification on top of 7d47df2 and that works! Thanks for the quick action @younesbelkada! 🙏
@larsmennen how are you loading your model? The description above is slightly misleading, as initially the plan was to add a kwarg when loading the model, as follows:
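(For reference, a sketch of that originally-planned call; the checkpoint name is just an example, and this kwarg was dropped before merging, so it is not part of the released API.)

```python
import torch
from transformers import T5ForConditionalGeneration

# Originally-proposed API in this PR's first iteration (no longer needed / supported):
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl",
    torch_dtype=torch.float16,
    keep_in_fp32_modules=["wo"],
)
```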
but now this is not needed; you should just load your model like:
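(Roughly, again with an illustrative checkpoint name, since the fp32 modules now come from the model class's `_keep_in_fp32_modules` attribute:)

```python
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-xl", torch_dtype=torch.float16
)
```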
@younesbelkada ah I see! I was passing the kwarg, yes, so that explains it.
Almost there, just one last comment and we should be good to merge! Thanks!
src/transformers/modeling_utils.py
Outdated
@@ -2276,11 +2290,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            )
            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

        # Check if `_keep_in_fp32_modules` is not None
        use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and is_accelerate_available()
This is also only relevant if `torch_dtype == torch.float16`, so maybe add it here?
This is the place to issue a warning, I think, if:
- `cls._keep_in_fp32_modules` is not None
- `torch_dtype == torch.float16`
- `is_accelerate_available()` is not true

to tell the user they should install Accelerate to have good predictions from the model (a sketch follows below).
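A minimal sketch of that combined check and warning (the helper name and the way `cls` / `torch_dtype` are passed in are assumptions here; they mirror the surrounding `from_pretrained` code rather than quote it):

```python
import torch
from transformers.utils import is_accelerate_available, logging

logger = logging.get_logger(__name__)

def resolve_keep_in_fp32(cls, torch_dtype):
    # Only relevant when the model is loaded in float16 and the class
    # declares modules that must stay in float32.
    wants_fp32_modules = cls._keep_in_fp32_modules is not None and torch_dtype == torch.float16
    if wants_fp32_modules and not is_accelerate_available():
        logger.warning(
            "Some weights of this model should be kept in float32 but `accelerate` is not installed; "
            "predictions may be degraded. Install it with `pip install accelerate`."
        )
    return wants_fp32_modules and is_accelerate_available()
```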
Agreed! Proposed your suggestions in 1d21843
@larsmennen this PR will be merged as soon as all the tests are green!
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
All slow tests from T5 (and BLOOM, just in case we didn't break anything else) pass 🟢
* add `keep_in_fp32_modules` support
* pass it as class attribute
* few modifs
  - make tests `slow`
  - fix logic
* better logic
* fix failing test
* `bfloat16` support
* Update src/transformers/modeling_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* fix
* simplify tests
* simplify tests
* fix test
* modify message
* more checks
* fix failing tests
* add more conditions
  - add `is_accelerate_available`
  - fixes pipeline tests that failed
* add suggestions
* Update src/transformers/modeling_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* fix failing `bnb` test
* add last safety checker

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
I tried this one with the latest version of transformers (27.4) and CUDA 10.2, and I get this error:
You need to do something like:
from transformers import T5ForConditionalGeneration
T5ForConditionalGeneration._keep_in_fp32_modules = ["wo"]
# your code here
Except this is already done for T5 ;-)
What does this PR do?
This PR partially addresses #20287 - although half-precision and int8 conversion work extremely well for most models, for some architectures (e.g. T5) the casting leads to a drastic performance degradation.
This can be fixed by manually force-casting some modules in `float32`. For FLAN-T5, @larsmennen and @navjotts have found that keeping only these weights in `fp32` enables running the largest models in fp16 or int8 with no performance degradation. This PR introduces a new util in the `from_pretrained` method, termed `keep_in_fp32_modules`, that partially addresses this issue.

How does this util work? For T5:
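(Roughly, the mechanism relies on a class attribute listing module-name substrings whose weights must stay in float32. A minimal sketch of how that surfaces for T5; the comment describes the expected value, not verified output:)

```python
from transformers import T5ForConditionalGeneration

# T5 declares the fp32-sensitive feed-forward output projections ("wo") as a
# class attribute; from_pretrained reads it and upcasts any parameter whose
# name contains one of these substrings when torch_dtype == torch.float16
# (and accelerate is installed).
print(T5ForConditionalGeneration._keep_in_fp32_modules)  # expected: ["wo"]
```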
When using `keep_in_fp32_modules`, `low_cpu_mem_usage` needs to be force-set to `True`. This is because if `low_cpu_mem_usage=False`, it is the PyTorch function `_load_from_state_dict` that is called under the hood on each sub-module. This function calls `copy_` from PyTorch, which seems to keep the tensor in its native `dtype` regardless of the `dtype` of the input module.

Keeping this as a draft for now, as this util needs to be manually patched with fixes such as #20287 (comment); otherwise users will encounter issues about incompatible `dtype` between inputs and weights.

cc @sgugger