Skip to content

Commit

Permalink
TF: unpack inputs on Convbert, GPTJ, LED, and templates (huggingface#16491)

Browse files Browse the repository at this point in the history

* Add unpack_inputs to remaining models

* remove stray use of inputs in the templates; fix tf.debugging of attn masks
  • Loading branch information
gante authored Mar 30, 2022
1 parent ae189ef commit c2f8eaf
Show file tree
Hide file tree
Showing 7 changed files with 466 additions and 933 deletions.
4 changes: 2 additions & 2 deletions src/transformers/models/bart/modeling_tf_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,12 +943,12 @@ def call(
# check if head_mask and cross_attn_head_mask have a correct number of layers specified if desired
# The tf.debugging asserts are not compliant with XLA then they
# have to be disabled in other modes than eager.
for attn_mask in [head_mask, cross_attn_head_mask]:
for attn_mask_name, attn_mask in [("head_mask", head_mask), ("cross_attn_head_mask", cross_attn_head_mask)]:
if attn_mask is not None and tf.executing_eagerly():
tf.debugging.assert_equal(
shape_list(attn_mask)[0],
len(self.layers),
message=f"The {attn_mask} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
message=f"The {attn_mask_name} should be specified for {len(self.layers)} layers, but it is for {shape_list(attn_mask)[0]}.",
)

for idx, decoder_layer in enumerate(self.layers):
Expand Down
Loading

0 comments on commit c2f8eaf

Please sign in to comment.