
DeBERTa's DisentangledSelfAttention hardcodes float dtype, which causes bfloat16 overflow error #35332

Open

bauwenst opened this issue Dec 19, 2024 · 1 comment · May be fixed by #35336

@bauwenst

System Info

transformers: 4.47.0
Python: 3.10.5
PyTorch: 2.5.1+cu124
GPU: NVIDIA GTX 980 Ti

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm training a DebertaForMaskedLM model within a broader experimental framework, but the bug can be reproduced with simple inference: instantiate such a model with dtype bfloat16 and send a batch through it.

import torch
from transformers import DebertaConfig, DebertaForMaskedLM

model = DebertaForMaskedLM._from_config(DebertaConfig(), torch_dtype=torch.bfloat16)
model(**{"input_ids": torch.tensor([[101,102,103,104]]),
         "attention_mask": torch.tensor([[1,1,1,1]])})

One of two errors is then thrown in modeling_deberta.py, both inside DisentangledSelfAttention.forward() and both traceable to the same root cause:

  1. RuntimeError: expected m1 and m2 to have the same dtype, but got: float != struct c10::BFloat16
  2. RuntimeError: value cannot be converted to type at::BFloat16 without overflow

Here's where they come from: two fields in DeBERTa's DisentangledSelfAttention are constructed by explicitly declaring their dtype as torch.float:

self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))

Then, in forward(), the tensors query_layer, key_layer and value_layer are created; they start out with the dtype of the hidden states, which is the dtype of the model, namely bfloat16:

qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)

But then one of these tensors, query_layer, is modified by adding self.q_bias to it, and the resulting tensor inherits the torch.float data type:

query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])

The first RuntimeError can occur on the following line, where query_layer (now torch.float) and key_layer (still torch.bfloat16) are multiplied. I've had this line crash on one machine and work on another, so this kind of mixed-dtype matmul apparently works on some setups but not others.

attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
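For reference, the mismatch can be reproduced in isolation; this is a standalone sketch (not code from modeling_deberta.py), and whether it actually raises appears to depend on the PyTorch build/backend:

import torch

q = torch.randn(1, 12, 4, 64, dtype=torch.float32)   # stand-in for query_layer after the q_bias addition
k = torch.randn(1, 12, 4, 64, dtype=torch.bfloat16)  # stand-in for key_layer, still in the model dtype

# On the setups where the first error appears, this raises
# "expected m1 and m2 to have the same dtype"; on others it goes through.
scores = torch.matmul(q, k.transpose(-1, -2))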

The second RuntimeError occurs even when mixed precision is supported. It happens on the following line:

attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)

attention_scores is of type bfloat16, but you ask to fill it with the minimal value for the data type of query_layer, not that of attention_scores. Because query_layer.dtype is torch.float, that minimal value (-3.40282e+38) is more negative than the most negative representable torch.bfloat16 value (-3.38953e+38). Hence the overflow.
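The overflow can also be seen in isolation (again a standalone sketch, not library code):

import torch

print(torch.finfo(torch.float32).min)   # -3.4028234663852886e+38
print(torch.finfo(torch.bfloat16).min)  # -3.3895313892515355e+38

scores = torch.zeros(1, 4, 4, dtype=torch.bfloat16)
mask = torch.zeros(1, 4, 4, dtype=torch.bool)
# Filling a bfloat16 tensor with the float32 minimum raises
# "RuntimeError: value cannot be converted to type at::BFloat16 without overflow".
scores.masked_fill(~mask, torch.finfo(torch.float32).min)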

Expected behavior

The dtype of self.q_bias and self.v_bias should follow the rest of the model's modules/tensors rather than being hardcoded. That would keep everything in bfloat16.
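Concretely, something like the following would do it (a minimal sketch of the kind of change meant here, not necessarily what #35336 implements): leaving out the explicit dtype lets the parameters be created and cast along with the rest of the module.

self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))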

@ArthurZucker
Collaborator

Reviewed the PR! We should make sure this change is as transparent as possible!
