diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py
index df3d4d95cd0170..0fbb66ba8fd054 100644
--- a/src/transformers/models/deberta/modeling_deberta.py
+++ b/src/transformers/models/deberta/modeling_deberta.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ PyTorch DeBERTa model."""
 
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
 
@@ -640,8 +639,8 @@ def linear(w, b, x):
             qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
             qkvb = [None] * 3
 
-            q = linear(qkvw[0], qkvb[0], query_states)
-            k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
+            q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
+            k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
             query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
 
         query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
@@ -650,8 +649,8 @@ def linear(w, b, x):
         rel_att = None
         # Take the dot product between "query" and "key" to get the raw attention scores.
         scale_factor = 1 + len(self.pos_att_type)
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
-        query_layer = query_layer / scale
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+        query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -711,13 +710,13 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd
         if "p2c" in self.pos_att_type:
             pos_query_layer = self.pos_q_proj(rel_embeddings)
             pos_query_layer = self.transpose_for_scores(pos_query_layer)
-            pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if query_layer.size(-2) != key_layer.size(-2):
                 r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
             else:
                 r_pos = relative_pos
             p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-            p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+            p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
             p2c_att = torch.gather(
                 p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
             ).transpose(-1, -2)
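
Note on the deberta hunks above: besides dropping the now-unused `import math`, every edit follows one pattern, namely keeping the scaling arithmetic in tensor space and making dtypes explicit. `torch.matmul` does not promote across mixed float dtypes, so feeding fp16 activations into fp32 weight slices (or vice versa) raises a dtype-mismatch RuntimeError; casting the input to the weight's dtype, as the new `torch.tensor(query_states, dtype=qkvw[0].dtype)` calls do, avoids that. The sketch below is a minimal, standalone illustration of that casting pattern, not the model code; only the local `linear(w, b, x)` helper loosely mirrors the diff, and the shapes and tensors are made up.

import torch

def linear(w, b, x):
    # Loosely mirrors the diff's local helper: x @ w.T with an optional bias.
    out = torch.matmul(x, w.t())
    return out if b is None else out + b

w = torch.randn(64, 64)                         # fp32 weight slice
x = torch.randn(2, 8, 64, dtype=torch.float16)  # fp16 activations

# Without the cast, torch.matmul(x, w.t()) raises a Half-vs-Float dtype error.
q = linear(w, None, torch.tensor(x, dtype=w.dtype))
print(q.dtype)  # torch.float32

(`x.to(w.dtype)` is the warning-free way to cast an existing tensor; `torch.tensor(x, dtype=...)` copies and emits a UserWarning on recent PyTorch, but it is the construct this diff uses throughout.)
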
diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
index 3243ee108d488b..1a9252a7d30707 100644
--- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -717,7 +717,9 @@ def forward(
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -799,7 +801,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -822,7 +824,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
 
         return score
 
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index fe5836a80f36e0..bcd95139c55898 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -791,7 +791,9 @@ def forward(
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+            scale, dtype=query_layer.dtype
+        )
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
@@ -873,7 +875,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
                 dim=-1,
                 index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
             )
-            score += c2p_att / scale
+            score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
 
         # position->content
         if "p2c" in self.pos_att_type:
@@ -896,7 +898,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
                 dim=-1,
                 index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
             ).transpose(-1, -2)
-            score += p2c_att / scale
+            score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
 
         return score
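
Note on the recurring `/ torch.tensor(scale, dtype=...)` edits in modeling_deberta_v2.py and modeling_sew_d.py: `scale` is a 0-dim fp32 tensor, and each change casts it to the dtype of whatever it divides (`query_layer`, `c2p_att`, `p2c_att`), so the result dtype never depends on implicit type promotion, which traced or exported graphs (e.g., ONNX) do not always resolve the way eager mode does. Below is a self-contained sketch of the pattern, with illustrative shapes and names of my own:

import torch

scale_factor = 3  # 1 + c2p + p2c, matching the pos_att_type logic in the hunks
head_dim = 64

# 0-dim fp32 scale, computed exactly as in the diff.
scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float) * scale_factor)

# Stand-in for torch.bmm(query_layer, key_layer.transpose(-1, -2)) in fp16.
attention_scores = torch.randn(8, 16, 16, dtype=torch.float16)

# Cast the divisor to the dividend's dtype before dividing, as the diff does.
attention_scores = attention_scores / torch.tensor(scale, dtype=attention_scores.dtype)
assert attention_scores.dtype == torch.float16

The same cast is applied to `c2p_att` and `p2c_att` before they are accumulated into `score`, so every term entering the attention scores shares a single dtype.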