Deberta V2: Fix critical trace warnings to allow ONNX export (huggingface#18272)

* Fix critical trace warnings to allow ONNX export

* Force input to `sqrt` to be float type

* Cleanup code

* Remove unused import statement

* Update model sew

* Small refactor

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Use broadcasting instead of repeat

* Implement suggestion

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Match deberta v2 changes in sew_d

* Improve code quality

* Update code quality

* Consistency of small refactor

* Match changes in sew_d

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
2 people authored and JingyaHuang committed Aug 11, 2022
1 parent 480b9d4 commit 6581aff
Showing 2 changed files with 6 additions and 6 deletions.
src/transformers/models/deberta_v2/modeling_deberta_v2.py (3 additions, 3 deletions)
@@ -584,7 +584,7 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
     """
     q_ids = torch.arange(0, query_size)
     k_ids = torch.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - k_ids.repeat(query_size, 1)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
     rel_pos_ids = rel_pos_ids.to(torch.long)
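As a quick aside (not part of the commit), the broadcasting form is a drop-in replacement for the `repeat` form: subtracting a (1, key_size) row from a (query_size, 1) column broadcasts to the same (query_size, key_size) matrix without materializing the repeated copy of `k_ids`. A minimal sketch with made-up sizes:

import torch

query_size, key_size = 5, 7
q_ids = torch.arange(0, query_size)
k_ids = torch.arange(0, key_size)

# Old form: tile k_ids into a (query_size, key_size) matrix, then subtract.
old = q_ids[:, None] - k_ids.repeat(query_size, 1)
# New form: let broadcasting expand the (1, key_size) row implicitly.
new = q_ids[:, None] - k_ids[None, :]

assert old.shape == new.shape == (query_size, key_size)
assert torch.equal(old, new)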
@@ -793,7 +793,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1) * scale_factor, dtype=torch.float))
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
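A minimal sanity check (not from the commit) that the reordering is value-preserving: wrapping only `size(-1)` in a float tensor and multiplying by `scale_factor` afterwards gives the same number as multiplying first, so `torch.sqrt` still receives a float operand, which appears to be what the "Force input to `sqrt` to be float type" message refers to. `head_dim` and `scale_factor` below are made-up stand-ins:

import torch

head_dim = 64       # stands in for pos_key_layer.size(-1)
scale_factor = 3    # stands in for the model's attention scale factor

# Old form: multiply the Python ints first, then wrap and cast to float.
old_scale = torch.sqrt(torch.tensor(head_dim * scale_factor, dtype=torch.float))
# New form: cast the size to a float tensor first, then scale.
new_scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float) * scale_factor)

assert new_scale.dtype == torch.float32
assert torch.isclose(old_scale, new_scale)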
@@ -805,7 +805,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
 
         # position->content
         if "p2c" in self.pos_att_type:
-            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1) * scale_factor, dtype=torch.float))
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),
src/transformers/models/sew_d/modeling_sew_d.py (3 additions, 3 deletions)
@@ -229,7 +229,7 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
     """
     q_ids = torch.arange(0, query_size)
     k_ids = torch.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - k_ids.repeat(query_size, 1)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
     rel_pos_ids = rel_pos_ids.to(torch.long)
@@ -867,7 +867,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1) * scale_factor, dtype=torch.float))
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -879,7 +879,7 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
 
         # position->content
         if "p2c" in self.pos_att_type:
-            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1) * scale_factor, dtype=torch.float))
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),
