Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hardcoded float dtypes in DeBERTa model, which caused multiple RuntimeErrors in bfloat16 #35336

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

bauwenst
Copy link

What does this PR do?

Fix #35332 by removing any hardcoded float dtypes.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker

@Rocketknight1
Copy link
Member

Hi @bauwenst, I think bias terms for several models, and sensitive computations like Attention and RoPE, are intended to be kept in float32, even when the overall model is bfloat16.

I agree that there's a bug here: The hidden states are forced to float32 because of the bias addition, but I think the right solution is just to cast the output back to self.dtype afterwards, and not reduce the bias precision, and definitely not reduce the precision in RoPE/Attention!

@bauwenst
Copy link
Author

I think the right solution is just to cast the output back to self.dtype afterwards, and not reduce the bias precision, and definitely not reduce the precision in RoPE/Attention!

@Rocketknight1 Isn't the point of training in bfloat16 that operations are done with 16 bits of precision? What you have laid out now is a case for not using bfloat16 at all because it indeed reduces precision in computations and in weights, but that's expected, no? Do you have any source that points out that it is particularly important for the bias to have more bits of precision than the rest of the model, rather than e.g. dense weights? To take it to an extreme: you could have your whole model be float32 and then cast down to bfloat16 whenever it is time to output hidden states. I don't see why the biases are special.

and sensitive computations like Attention and RoPE, are intended to be kept in float32, even when the overall model is bfloat16.

I doubt that this is intended in DeBERTa. When you let the bias follow the dtype of the reset of the model, everything stays bfloat16 in the attention computation. If it was intended to do attention with higher precision, surely an explicit cast operation would have been put in, rather than accidentally making the query vectors (and yet not the key vectors) float32 as a byproduct of a simple addition?

@Rocketknight1
Copy link
Member

Rocketknight1 commented Dec 19, 2024

Hi @bauwenst, float32 biases are common for two reasons:

  1. They are very small and do not take part in any multiplication operations, so leaving them in float32 does not affect performance and only has a tiny effect on memory usage
  2. Precision errors in biases have a much larger impact than precision errors in weights. Precision errors in weights tend to "cancel out" during matrix multiplication, and each individual weight is multiplied by a (usually small) input activation, which means its effect on the output value is also small. Biases are directly added to the final total, so their impact can be very large.

If you use quantization libraries like llama.cpp you will see this in action - weight matrices for feedforward layers lose the most precision, matrices in attention layers get higher precision, and small tensors that directly modify activations (like biases and LayerNorm weight vectors) are kept in full float32 precision. It is also common for many models in transformers to upcast to float32 for LayerNorm or attention computations.

@bauwenst
Copy link
Author

@Rocketknight1 Okay, that makes some sense; do note that DeBERTa does not seem to be one of these models that upcast to do attention with higher precision.

I have two questions in that case:

  1. Is there a model in transformers that has fixed-dtype biases but does not have the type errors fixed in my commit? This would then be the reference implementation for what to do with DeBERTa.
  2. Should DeBERTa explicitly upcast its precision in its attention computation, given that this is a big deal?

@Rocketknight1
Copy link
Member

Hi @bauwenst, regarding 1, I think you can just add the bias term and then cast the float32 answer to self.dtype . I don't know if there's a good reference model for this. ModernBERT is the most up-to-date MLM model in transformers but it just uses the built-in bias in nn.Linear .

As for question 2, I'm afraid I don't have an answer! In general, models are added to transformers with the goal of matching their original implementation exactly. In the case of DeBERTa, the model was written 4 years ago, before bfloat16 training was used, and it uses an unusual attention mechanism, so the "right" way to handle that with reduced precision is probably an open research question, lol. If you check the codebase, most modern models do explicitly upcast to float32 for the softmax computation, because softmax can be numerically unstable, but there is some variation in whether they upcast other parts of attention.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! Thanks for opening the PR.
The main reason we usually don't patch this is because this was hardcoded by the original authors, and thus we would be changing the results for people that use to rely on the wrong behavior.

Fixing the mask filling does make sense tho! We can do it 🤗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

DeBERTa's DisentangledSelfAttention hardcodes float dtype, which causes bfloat16 overflow error
3 participants