From 6581affea99c83168212f9799fe349eedd23d182 Mon Sep 17 00:00:00 2001
From: iiLaurens
Date: Thu, 11 Aug 2022 15:54:43 +0200
Subject: [PATCH] Deberta V2: Fix critical trace warnings to allow ONNX export
 (#18272)

* Fix critical trace warnings to allow ONNX export

* Force input to `sqrt` to be float type

* Cleanup code

* Remove unused import statement

* Update model sew

* Small refactor

Co-authored-by: Michael Benayoun

* Use broadcasting instead of repeat

* Implement suggestion

Co-authored-by: Michael Benayoun

* Match deberta v2 changes in sew_d

* Improve code quality

* Update code quality

* Consistency of small refactor

* Match changes in sew_d

Co-authored-by: Michael Benayoun
---
 src/transformers/models/deberta_v2/modeling_deberta_v2.py | 6 +++---
 src/transformers/models/sew_d/modeling_sew_d.py           | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
index 79201763a167e0..987827ea08c886 100644
--- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -584,7 +584,7 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
     """
     q_ids = torch.arange(0, query_size)
     k_ids = torch.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - k_ids.repeat(query_size, 1)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
     rel_pos_ids = rel_pos_ids.to(torch.long)
@@ -793,7 +793,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1) * scale_factor, dtype=torch.float))
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -805,7 +805,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
 
         # position->content
         if "p2c" in self.pos_att_type:
-            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1) * scale_factor, dtype=torch.float))
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index a5b52881e8023f..36a30fde7cdd08 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -229,7 +229,7 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
     """
    q_ids = torch.arange(0, query_size)
     k_ids = torch.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - k_ids.repeat(query_size, 1)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
     rel_pos_ids = rel_pos_ids.to(torch.long)
@@ -867,7 +867,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1) * scale_factor, dtype=torch.float))
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -879,7 +879,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
 
         # position->content
         if "p2c" in self.pos_att_type:
-            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1) * scale_factor, dtype=torch.float))
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),
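Note on the `build_relative_position` change: the patch replaces `k_ids.repeat(query_size, 1)` with the broadcast view `k_ids[None, :]`, which yields the same relative-position matrix without a repeat() call that takes the Python integer `query_size` as a shape argument. A minimal standalone sketch of the equivalence; the sizes below are illustrative, not taken from the model:

    import torch

    query_size, key_size = 4, 6  # illustrative sequence lengths

    q_ids = torch.arange(0, query_size)
    k_ids = torch.arange(0, key_size)

    # Old form: repeat() materializes a (query_size, key_size) copy of k_ids.
    # Under torch.jit.trace, the Python int shape argument can be baked into the
    # exported graph as a constant, which is what produces trace warnings.
    old = q_ids[:, None] - k_ids.repeat(query_size, 1)

    # New form: plain broadcasting computes the same (query_size, key_size)
    # matrix of relative positions with no explicit shape argument at all.
    new = q_ids[:, None] - k_ids[None, :]

    assert torch.equal(old, new)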
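Note on the `scale` change: moving `scale_factor` outside the `torch.tensor(...)` call builds a float tensor from the head size first, so the value reaching `torch.sqrt` is float-typed (the "Force input to `sqrt` to be float type" item in the commit message). A before/after sketch, with `head_dim` and `scale_factor` as illustrative stand-ins for `pos_key_layer.size(-1)` and the model's `scale_factor`:

    import math

    import torch

    head_dim, scale_factor = 64, 3  # illustrative values

    # Old form: the int * int product is computed in Python and only then
    # wrapped in a tensor; under tracing this integer arithmetic can fall
    # outside the graph and trigger trace warnings.
    old_scale = torch.sqrt(torch.tensor(head_dim * scale_factor, dtype=torch.float))

    # New form: convert the size to a float tensor first, then multiply by the
    # integer scale_factor, so torch.sqrt always receives a float tensor.
    new_scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float) * scale_factor)

    # Both forms are numerically identical.
    assert torch.isclose(new_scale, old_scale)
    assert math.isclose(new_scale.item(), math.sqrt(head_dim * scale_factor))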