Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hardcoded float dtypes in DeBERTa model, which caused multiple RuntimeErrors in bfloat16 #35336

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/transformers/models/deberta/modeling_deberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, size, eps=1e-12):

def forward(self, hidden_states):
input_type = hidden_states.dtype
hidden_states = hidden_states.float()
hidden_states = hidden_states.float() # TODO: Even when working in bfloat16?
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
Expand Down Expand Up @@ -134,7 +134,7 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
# Full credits to @Szustarol
@torch.jit.script
def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=query_layer.dtype) * scale_factor)


@torch.jit.script
Expand Down Expand Up @@ -184,8 +184,8 @@ def __init__(self, config):
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.q_bias = nn.Parameter(torch.zeros((self.all_head_size)))
self.v_bias = nn.Parameter(torch.zeros((self.all_head_size)))
self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []

self.relative_attention = getattr(config, "relative_attention", False)
Expand Down Expand Up @@ -271,8 +271,7 @@ def forward(
rel_att: int = 0
# Take the dot product between "query" and "key" to get the raw attention scores.
scale_factor = 1 + len(self.pos_att_type)
scale = scaled_size_sqrt(query_layer, scale_factor)
query_layer = query_layer / scale.to(dtype=query_layer.dtype)
query_layer = query_layer / scaled_size_sqrt(query_layer, scale_factor)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.relative_attention and rel_embeddings is not None and relative_pos is not None:
Expand All @@ -287,7 +286,7 @@ def forward(
attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

attention_mask = attention_mask.bool()
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(attention_scores.dtype).min)
# bsz x height x length x dimension
attention_probs = nn.functional.softmax(attention_scores, dim=-1)

Expand Down Expand Up @@ -1133,7 +1132,7 @@ def forward(
)
labels = torch.gather(labels, 0, label_index.view(-1))
loss_fct = CrossEntropyLoss()
loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
loss = loss_fct(labeled_logits.view(-1, self.num_labels).to(dtype=encoder_layer.dtype), labels.view(-1))
else:
loss = torch.tensor(0).to(logits)
else:
Expand Down