
Add keep_in_fp32_modules support #20683

Merged

younesbelkada merged 20 commits into huggingface:main on Dec 13, 2022

Conversation

@younesbelkada (Contributor) commented Dec 8, 2022

What does this PR do?

This PR partially addresses #20287. Although half-precision and int8 conversion work extremely well for most models, for some architectures (e.g. T5) the casting leads to drastic performance degradation.

This can be fixed by manually force-casting some modules to float32. For FLAN-T5, @larsmennen and @navjotts found that keeping only these weights in fp32 makes it possible to run the largest models in fp16 or int8 with no performance degradation.

This PR introduces a new util in the from_pretrained method, keep_in_fp32_modules, that partially addresses this issue.

How does this util work? For T5:

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, keep_in_fp32_modules=["wo"])
print(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype)
>>> torch.float32

When using keep_in_fp32_modules, low_cpu_mem_usage needs to be force-set to True. This is because with low_cpu_mem_usage=False, PyTorch's _load_from_state_dict is called under the hood on each sub-module. That function uses PyTorch's copy_, which keeps the destination tensor in its native dtype regardless of the dtype of the source tensor:

import torch

param = torch.Tensor([0.1, 0.2, 0.3]).to(torch.float16)
to_copy_param = torch.Tensor([0.2, 0.1, 0.3]).to(torch.float32)

param.copy_(to_copy_param)  # the fp16 destination keeps its dtype; the fp32 source is cast on copy
print(param.dtype)
>>> torch.float16

Keeping this as a draft for now, as this util needs to be combined with manual patches such as #20287 (comment); otherwise users will encounter errors about incompatible dtypes between inputs and weights.
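The patches in question essentially cast the activations to the weight dtype right before the affected projection. A rough sketch of the idea (illustrative only, not the exact patch from #20287):

import torch
import torch.nn as nn

def project_with_dtype_guard(linear: nn.Linear, hidden_states: torch.Tensor) -> torch.Tensor:
    # If the layer was kept in fp32 while the activations are fp16,
    # upcast the activations so the matmul dtypes match.
    if hidden_states.dtype != linear.weight.dtype:
        hidden_states = hidden_states.to(linear.weight.dtype)
    return linear(hidden_states)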

cc @sgugger

@younesbelkada younesbelkada requested a review from sgugger December 8, 2022 15:41
@younesbelkada (Contributor, Author):

What about adding hooks on each converted module, that will take care of converting the input / output to the correct dtype ?
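A minimal sketch of what such hooks could look like (hypothetical helper, not part of this PR):

import torch

def add_io_casting_hooks(module, io_dtype=torch.float16):
    # Keep `module` in fp32 but cast its inputs/outputs so it composes
    # with the surrounding half-precision layers.
    def pre_hook(mod, inputs):
        # upcast incoming activations to match the fp32 weights
        return tuple(x.float() if torch.is_tensor(x) else x for x in inputs)

    def post_hook(mod, inputs, output):
        # downcast the result back to the half-precision dtype used elsewhere
        return output.to(io_dtype) if torch.is_tensor(output) else output

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)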

@sgugger (Collaborator) left a comment:

Some initial comments:

Comment on lines 1995 to 2000
if keep_in_fp32_modules is not None and not low_cpu_mem_usage:
    # Force `low_cpu_mem_usage` to be set to `True` - check the PR:
    logger.warning(
        "The argument `keep_in_fp32_modules` is used, force-enabling `low_cpu_mem_usage` to load the model"
    )
    low_cpu_mem_usage = True
Collaborator:

Shouldn't be force-set here.

@younesbelkada (Contributor, Author), Dec 8, 2022:

proposed something in 115c0d0

@HuggingFaceDocBuilderDev commented Dec 8, 2022

The documentation is not available anymore as the PR was closed or merged.

@younesbelkada younesbelkada requested a review from sgugger December 8, 2022 16:42
@younesbelkada (Contributor, Author):

As suggested in #20287, models loaded in bfloat16 should keep their weights in bfloat16 and not be cast to fp32. This is addressed in e3498da.
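In other words, the upcast should only apply to fp16 loads. A rough sketch of the intended condition (illustrative names, not the commit's exact code):

import torch

def should_upcast_to_fp32(param_name, torch_dtype, keep_in_fp32_modules):
    # Only upcast matching parameters when loading in fp16,
    # so bf16 loads keep their original dtype.
    if torch_dtype != torch.float16 or not keep_in_fp32_modules:
        return False
    return any(module_name in param_name for module_name in keep_in_fp32_modules)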

Comment on lines 2294 to 2297
logger.warning(
    " `_keep_in_fp32_modules` is not set to `None` and you don't have `accelerate` installed",
    " it is recommended to have `accelerate` installed in this case `pip install accelerate`.",
)
Collaborator:

The warning should only be triggered when torch_dtype == torch.float16.

Contributor Author:

Should be fixed in 8014c34

Comment on lines 2326 to 2328
keep_in_fp32_modules = model._keep_in_fp32_modules
if keep_in_fp32_modules is not None:
    low_cpu_mem_usage = True
Collaborator:

Shouldn't this use use_keep_in_fp32_modules here? Also should go before so the test at line 2307 can be simplified.

Contributor Author:

I simplified the tests in 966cc06, but I think we still need keep_in_fp32_modules = model._keep_in_fp32_modules, as it is used later on at line 2342.

Contributor Author:

I think you are right: 0f75387

Contributor Author:

Actually it seems that it's trickier than that: putting it on top results in some failing tests. This should be fixed in 243e6b5.

# upcast in fp32 if any
target_dtype = dtype
if keep_in_fp32_modules is not None and any(
    module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules
):
    target_dtype = torch.float32
Collaborator:

should also add a test of dtype being float16 here.

Contributor Author:

added in 8014c34

younesbelkada and others added 5 commits December 9, 2022 15:35
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@@ -2299,6 +2323,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with ContextManagers(init_contexts):
    model = cls(config, *model_args, **model_kwargs)

if use_keep_in_fp32_modules:
    low_cpu_mem_usage = True
    keep_in_fp32_modules = model._keep_in_fp32_modules
Collaborator:

Let's set it to [] here if it's None, so that we don't have to check again later on.

Contributor Author:

This should be addressed in cb89c42

Comment on lines 2588 to 2591
elif keep_in_fp32_modules is not None and state_dict is not None:
    for key in state_dict:
        if any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules):
            state_dict[key] = state_dict[key].to(torch.float32)
Collaborator:

This is not useful: with PyTorch's load_state_dict, the weights are converted to the dtype of the parameters inside the model. So it's the model dtype that you should fix here.

Also, this removes the need for the Accelerate warning above, no?
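For illustration, a rough sketch of fixing the model dtype instead of the state dict (hypothetical helper, not the PR's actual code):

import torch

def keep_modules_in_fp32(model, keep_in_fp32_modules):
    # Upcast the parameters of matching submodules so that load_state_dict,
    # which casts incoming tensors to the parameter dtype, leaves them in fp32.
    for name, param in model.named_parameters():
        if any(module_name in name for module_name in keep_in_fp32_modules):
            param.data = param.data.to(torch.float32)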

Contributor Author:

Yes! Should be addressed in cb89c42

@larsmennen (Contributor):

Thanks @younesbelkada and @sgugger!! Tested this locally; I can confirm this works with patches 1 & 2 from #20287 (comment).

The only problem I encountered is that in:

model = cls(config, *model_args, **model_kwargs)

You get an error, as keep_in_fp32_modules is an unexpected keyword to the underlying model class (locally I just added it quickly to test). Do you want to add this in, so people can use it in their model class to determine where to apply patches like 1 & 2? Or alternatively, don't pass it on, and people can just query the dtype.

@younesbelkada (Contributor, Author) commented Dec 9, 2022

Thanks so much @larsmennen for confirming that the tests pass! We should be close to merging this 💪
I think that your failing test should be fixed with my latest commit (cb89c42), but I am not sure. Could you try again with the latest commit? 🙏

@@ -2070,6 +2070,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Load model
loading_info = None

# Keep in fp32 modules
keep_in_fp32_modules = None
Contributor Author:

@larsmennen the keyword has been added here if this is what you meant

- add `is_accelerate_available`
- fixes pipeline tests that failed
@larsmennen (Contributor):

Hmm, that doesn't fix it. I think you just need to pop the argument from model_kwargs, otherwise it gets passed to the underlying model (I'm assuming you don't want that, but correct me if I'm wrong).

I.e. after

commit_hash = kwargs.pop("_commit_hash", None)

if you add

keep_in_fp32_modules = kwargs.pop("keep_in_fp32_modules", None)

I tested with that modification on top of 7d47df2 and that works! Thanks for the quick action @younesbelkada! 🙏

@younesbelkada (Contributor, Author):

@larsmennen how are you loading your model? The description above is slightly misleading, as initially the plan was to add a kwarg when loading the model as follows:

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, keep_in_fp32_modules=["wo"])

but now this is not needed, you should just load your model like:

from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small", device_map="auto", load_in_8bit=True)

@larsmennen (Contributor):

@younesbelkada ah, I see! I was passing the kwarg, yes, so that explains it.

@younesbelkada younesbelkada marked this pull request as ready for review December 12, 2022 09:41
@sgugger (Collaborator) left a comment:

Almost there, just one last comment and we should be good to merge! Thanks!

@@ -2276,11 +2290,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = cls._keep_in_fp32_modules is not None and is_accelerate_available()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also only relevant if torch_dtype == torch.float16, so maybe add it here?

This is the place to issue a warning, I think, if:

  • cls._keep_in_fp32_modules is not None
  • torch_dtype == torch.float16
  • is_accelerate_available() is not true

to tell the user they should install Accelerate to have good predictions from the model (see the sketch after this list).
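A self-contained sketch of that suggested condition and warning (illustrative names, not the PR's exact code):

import logging

import torch

logger = logging.getLogger(__name__)

def resolve_keep_in_fp32(keep_in_fp32_modules, torch_dtype, accelerate_available):
    # Only activate the fp32-upcast path for fp16 loads, and warn when
    # accelerate is missing in that situation.
    relevant = keep_in_fp32_modules is not None and torch_dtype == torch.float16
    if relevant and not accelerate_available:
        logger.warning(
            "Some modules of this model should be kept in fp32 but `accelerate` is not installed. "
            "Install it with `pip install accelerate` to get correct predictions."
        )
    return relevant and accelerate_available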

Contributor Author:

Agreed! Proposed your suggestions in 1d21843

@younesbelkada (Contributor, Author):

@larsmennen this PR will be merged as soon as all the tests are green!
Would you mind opening a PR addressing your suggestions (patches 1 & 2 from the discussion at #20287)?

@younesbelkada (Contributor, Author):

All slow tests from T5 (and BLOOM just in case we didn't break anything else) pass 🟢
Merging once the CI tests are green

@younesbelkada younesbelkada merged commit 1af4bee into huggingface:main Dec 13, 2022
mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022
* add `keep_in_fp32_modules` support

* pass it as class attribute

* few modifs

- make tests `slow`
- fix logic

* better logic

* fix failing test

* `bfloat16` support

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix

* simplify tests

* simplify tests

* fix test

* modify message

* more checks

* fix failing tests

* add more conditions

- add `is_accelerate_available`
- fixes pipeline tests that failed

* add suggestions

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix failing `bnb` test

* add last safety checker

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
@Oxi84 commented Apr 5, 2023

I tried this with the latest version of transformers (4.27) and CUDA 10.2, and I get this error:

model1a_CPU = T5ForConditionalGeneration.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, keep_in_fp32_modules=["wo"]).to("cuda")

TypeError: __init__() got an unexpected keyword argument 'keep_in_fp32_modules'

@sgugger (Collaborator) commented Apr 6, 2023

keep_in_fp32_modules is not an argument you can pass to from_pretrained; this is all handled internally.

@younesbelkada younesbelkada deleted the add-fp32-modules branch April 6, 2023 08:06
@younesbelkada (Contributor, Author):

You need to do something like:

from transformers import T5ForConditionalGeneration

T5ForConditionalGeneration._keep_in_fp32_modules = ["wo"]

# your code here
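For example, a complete usage sketch of this workaround (hypothetical, reusing the t5-small checkpoint from the earlier example):

import torch
from transformers import T5ForConditionalGeneration

# Force the class to keep the `wo` projections in fp32, then load in fp16.
T5ForConditionalGeneration._keep_in_fp32_modules = ["wo"]
model = T5ForConditionalGeneration.from_pretrained(
    "t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True
)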

@sgugger (Collaborator) commented Apr 6, 2023

Except this is already done for T5 ;-)
