[BT] Add fp16 support #859
Conversation
@younesbelkada I think the proper solution is to put back:

    mask_value = torch.finfo(attn_weights.dtype).min
    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)

in the forward instead of it being stateful (currently it is always in fp32). WDYT? Looking at the reference code in https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html?highlight=scaled_dot_product_attention#torch.nn.functional.scaled_dot_product_attention, I think casting to bool is bad. In any case, the solution I propose should avoid any casting altogether.
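A minimal sketch of that suggestion (editor's illustration; the helper name `fill_value_for` is not from the PR): recompute the fill value from the runtime tensor on every call, so it always matches the current dtype and device.

```python
import torch

def fill_value_for(attn_weights: torch.Tensor) -> torch.Tensor:
    # Derive the mask fill value from the tensor we are about to mask,
    # instead of caching a value computed at module init time.
    mask_value = torch.finfo(attn_weights.dtype).min
    return torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)

# In fp16 this yields a scalar tensor of -65504.0 on the same device as the weights.
print(fill_value_for(torch.zeros(2, 2, dtype=torch.float16)))
```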
thanks @fxmarty for the heads up, will try now
    mask_value = torch.finfo(value.dtype).min
    attention_mask = torch.full([], mask_value, dtype=value.dtype).to(value.device)
This will probably break the logits tests
@fxmarty I think the issue is that sometimes the
@@ -74,12 +76,15 @@ def wrapped_scaled_dot_product(
                 torch.bool
             )

-            causal_mask = torch.where(causal_mask, 0, self._mask_value)
+            causal_mask = torch.where(causal_mask, 0, self._mask_value).to(value.dtype)
This cast is bad. Can we instead move the definition of mask_value in the forward?
    query = query.to(value.dtype)
    key = key.to(value.dtype)
Is this really needed? It should already be of the same dtype, no?
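For context (editor's note, not from the thread): `Tensor.to(dtype)` returns the input tensor unchanged when the dtype already matches, so these casts are free in that case and only take effect if an upstream op promoted `query`/`key` to a wider dtype than `value`.

```python
import torch

q = torch.randn(2, 4, dtype=torch.float16)
v = torch.randn(2, 4, dtype=torch.float16)

# When dtypes already match, .to() is a no-op and returns the very same tensor object.
assert q.to(v.dtype) is q
```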
The documentation is not available anymore as the PR was closed or merged.
    # gpt-2
    if config.model_type == "gpt2":
        target_dtype = self.gpt_layer.c_proj.weight.dtype
    # gpt-neo-x
    elif config.model_type == "gpt_neox":
        target_dtype = self.gpt_layer.dense.weight.dtype
    # gpt-j
    else:
        target_dtype = self.gpt_layer.out_proj.weight.dtype

    self.downcast_qk = config.model_type in ["gptj", "gpt_neox"]

    mask_value = torch.finfo(target_dtype).min
    self._mask_value = torch.full([], mask_value, dtype=target_dtype)
This will IMO not work because the user may call model = model.to(torch.float16) after initializing the model. In that case, self._mask_value would still be in fp32. I think we really need it in the forward.
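A short sketch of the failure mode being described (editor's illustration; the module and names are arbitrary): a mask value materialized at init keeps its original dtype even after the whole model is later cast to fp16.

```python
import torch

layer = torch.nn.Linear(4, 4)  # fp32 weights at construction time
# Stateful variant: the fill value is frozen to the init-time dtype (fp32 here).
mask_value = torch.full([], torch.finfo(layer.weight.dtype).min, dtype=layer.weight.dtype)

layer = layer.to(torch.float16)  # user casts the model afterwards
print(layer.weight.dtype)  # torch.float16
print(mask_value.dtype)    # torch.float32 -- the cached value did not follow the cast
```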
Nice catch!
    target_dtype = self.gpt_layer.k_proj.weight.dtype
    mask_value = torch.finfo(target_dtype).min
    self._mask_value = torch.full([], mask_value, dtype=target_dtype)
same
    target_dtype = self.gpt_layer.qkv_proj.weight.dtype
    mask_value = torch.finfo(target_dtype).min
    self._mask_value = torch.full([], mask_value, dtype=target_dtype)
same
@@ -74,12 +76,18 @@ def wrapped_scaled_dot_product(
                 torch.bool
             )

-            causal_mask = torch.where(causal_mask, 0, self._mask_value)
+            causal_mask = torch.where(causal_mask, 0, self._mask_value.to(value.dtype))
This is not equivalent:

    import torch

    mask_value = torch.finfo(torch.float32).min
    mask_value = torch.full([], mask_value, dtype=torch.float32)
    casted = mask_value.to(torch.float16)

    mask_value = torch.finfo(torch.float16).min
    mask_value = torch.full([], mask_value, dtype=torch.float16)

    # this assert fails: float32's min overflows to -inf when cast to float16,
    # whereas torch.finfo(torch.float16).min is -65504.0
    assert torch.equal(casted, mask_value)

Not sure if it has any influence or not, though. I would just put the definition of mask_value in the forward directly, as in transformers.
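For reference, the concrete values behind the non-equivalence (editor's addition):

```python
import torch

print(torch.finfo(torch.float32).min)  # -3.4028234663852886e+38
print(torch.full([], torch.finfo(torch.float32).min).to(torch.float16))  # tensor(-inf, dtype=torch.float16)
print(torch.finfo(torch.float16).min)  # -65504.0
```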
Thanks, adapted as suggested!
@fxmarty btw
LGTM, thank you for iterating on this!
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
What does this PR do?
Currently, on the main branch, fp16 inference for BetterTransformer decoder models is not supported; this PR aims to fix this.

TODO
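As a usage sketch of what this enables (editor's illustration; the model name is arbitrary, `BetterTransformer.transform` is the existing optimum entry point, and a CUDA device is assumed):

```python
import torch
from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

# Load a decoder model directly in fp16 and convert it to its BetterTransformer version.
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16).to("cuda")
model = BetterTransformer.transform(model)
```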
cc @fxmarty