
[BT] Add fp16 support #859

Merged: 15 commits into huggingface:main on Mar 7, 2023

Conversation

younesbelkada (Contributor) commented on Mar 7, 2023

What does this PR do?

Currently, fp16 inference for BetterTransformer decoder models is not supported on the main branch; this PR aims to fix that.

TODO

  • add tests

cc @fxmarty

fxmarty (Contributor) commented on Mar 7, 2023

@younesbelkada I think the proper solution should be to put back:

mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)

in the forward, instead of keeping it stateful (currently it was always in fp32). WDYT?

Looking at the reference code in https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html?highlight=scaled_dot_product_attention#torch.nn.functional.scaled_dot_product_attention, I kind of think casting to bool is bad. In any case, the solution I propose should avoid any casting altogether.
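For illustration, a minimal standalone sketch of that idea, assuming a small helper that builds the additive causal mask at call time (the function name and the boolean-mask setup are hypothetical, not the PR's code); the fill value is derived from the runtime dtype instead of being cached at init:

import torch

# Minimal sketch: build the additive causal mask inside the forward path,
# deriving the fill value from the runtime dtype (taken here from `value`).
def build_causal_mask(causal_bool_mask, value):
    mask_value = torch.finfo(value.dtype).min
    mask_value = torch.full([], mask_value, dtype=value.dtype, device=value.device)
    # Kept positions become 0, masked positions get the most negative representable value.
    return torch.where(causal_bool_mask, torch.zeros_like(mask_value), mask_value)

# Works unchanged whether the model runs in fp32 or fp16.
value = torch.randn(1, 4, 8, dtype=torch.float16)
bool_mask = torch.tril(torch.ones(4, 4)).to(torch.bool)
additive_mask = build_causal_mask(bool_mask, value)  # dtype: torch.float16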

younesbelkada (Contributor, author) commented:

Thanks @fxmarty for the heads up, will try now.

Comment on lines 131 to 132
mask_value = torch.finfo(value.dtype).min
attention_mask = torch.full([], mask_value, dtype=value.dtype).to(value.device)
younesbelkada (Contributor, author) commented:

This will probably break the logits tests.

younesbelkada (Contributor, author) commented on Mar 7, 2023

@fxmarty I think the issue is that sometimes the attention_mask is provided on the forward pass, so we need to cast it in that case, no?
I added logits tests as well, FYI.
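For context, a rough sketch of the case being described, assuming the usual additive-mask convention (the helper name and signature are illustrative, not the PR's code): a mask passed by the caller may arrive in fp32 and needs a cast before it is combined with the causal mask.

import torch
from typing import Optional

# Illustrative only: combine a caller-provided additive attention mask with the
# causal mask, casting it to the computation dtype when needed.
def merge_masks(causal_mask: torch.Tensor, attention_mask: Optional[torch.Tensor], value: torch.Tensor) -> torch.Tensor:
    if attention_mask is None:
        return causal_mask
    return causal_mask + attention_mask.to(value.dtype)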

@@ -74,12 +76,15 @@ def wrapped_scaled_dot_product(
torch.bool
)

causal_mask = torch.where(causal_mask, 0, self._mask_value)
causal_mask = torch.where(causal_mask, 0, self._mask_value).to(value.dtype)
fxmarty (Contributor) commented:

This cast is bad. Can we instead move the definition of mask_value into the forward?

Comment on lines 85 to 86
query = query.to(value.dtype)
key = key.to(value.dtype)
fxmarty (Contributor) commented:

Is this really needed? It should already be the same dtype, no?

HuggingFaceDocBuilderDev commented on Mar 7, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 41 to 54
# gpt-2
if config.model_type == "gpt2":
target_dtype = self.gpt_layer.c_proj.weight.dtype
# gpt-neo-x
elif config.model_type == "gpt_neox":
target_dtype = self.gpt_layer.dense.weight.dtype
# gpt-j
else:
target_dtype = self.gpt_layer.out_proj.weight.dtype

self.downcast_qk = config.model_type in ["gptj", "gpt_neox"]

mask_value = torch.finfo(target_dtype).min
self._mask_value = torch.full([], mask_value, dtype=target_dtype)
fxmarty (Contributor) commented:

This will IMO not work, because the user may call model = model.to(torch.float16) after initializing the model. In that case, self._mask_value would still be in fp32. I think we really need it in the forward.
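A standalone sketch of why that fails (not the PR's code): a plain tensor attribute assigned in __init__ is neither a parameter nor a registered buffer, so Module.to() never converts it.

import torch

class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        # Plain attribute, not a parameter or registered buffer:
        self._mask_value = torch.full([], torch.finfo(self.linear.weight.dtype).min)

layer = Layer().to(torch.float16)
print(layer.linear.weight.dtype)  # torch.float16
print(layer._mask_value.dtype)    # still torch.float32, so the fill value is stale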

younesbelkada (Contributor, author) commented:

Nice catch!

Comment on lines 121 to 123
target_dtype = self.gpt_layer.k_proj.weight.dtype
mask_value = torch.finfo(target_dtype).min
self._mask_value = torch.full([], mask_value, dtype=target_dtype)
fxmarty (Contributor) commented:

same

Comment on lines 188 to 190
target_dtype = self.gpt_layer.qkv_proj.weight.dtype
mask_value = torch.finfo(target_dtype).min
self._mask_value = torch.full([], mask_value, dtype=target_dtype)
fxmarty (Contributor) commented:

same

@@ -74,12 +76,18 @@ def wrapped_scaled_dot_product(
torch.bool
)

causal_mask = torch.where(causal_mask, 0, self._mask_value)
causal_mask = torch.where(causal_mask, 0, self._mask_value.to(value.dtype))
fxmarty (Contributor) commented on Mar 7, 2023:

This is not equivalent:

import torch

mask_value = torch.finfo(torch.float32).min
mask_value = torch.full([], mask_value, dtype=torch.float32)
casted = mask_value.to(torch.float16)  # overflows to -inf

mask_value = torch.finfo(torch.float16).min
mask_value = torch.full([], mask_value, dtype=torch.float16)  # -65504
assert torch.equal(casted, mask_value)  # fails: -inf != -65504

Not sure whether it has any influence, though. I would just put the definition of mask_value in the forward directly, as in transformers.

younesbelkada (Contributor, author) commented:

Thanks, adapted as suggested!

younesbelkada (Contributor, author) commented:

@fxmarty btw, self._mask_value seems to be unused for T5 and OPT, shall we remove it there?

fxmarty (Contributor) left a comment:

LGTM, thank you for iterating on this!

younesbelkada and others added 2 commits March 7, 2023 13:47
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
@younesbelkada younesbelkada merged commit 5a7b923 into huggingface:main Mar 7, 2023
@younesbelkada younesbelkada deleted the add-bt-fp16-support branch March 7, 2023 13:41