Commit

Tiny TF Bart fixes (#8023)
LysandreJik authored Oct 26, 2020
1 parent 0774786 commit 8be9cb0
Showing 1 changed file with 3 additions and 1 deletion.
src/transformers/modeling_tf_bart.py: 3 additions & 1 deletion
@@ -822,7 +822,7 @@ def _prepare_bart_decoder_inputs(
     if decoder_attn_mask is None:
         decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
     else:
-        decoder_padding_mask = invert_mask(tf.Tensor)
+        decoder_padding_mask = invert_mask(decoder_attn_mask)
 
     causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype)
     return decoder_input_ids, decoder_padding_mask, causal_lm_mask
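
This one-line fix matters because the old code passed the tf.Tensor class itself to invert_mask instead of the user-supplied decoder_attn_mask, so a custom decoder attention mask was never actually applied. As a rough illustration of what mask inversion means, here is a minimal sketch in plain TensorFlow; invert_mask_sketch is a hypothetical name, and the real invert_mask helper in modeling_tf_bart.py may differ in signature and dtype handling:

import tensorflow as tf

def invert_mask_sketch(attention_mask):
    # Hypothetical stand-in for invert_mask: in the incoming mask, 1 marks a real
    # token and 0 marks padding; the inverted mask is True exactly at padding
    # positions, where downstream code can then place a large negative bias.
    return tf.cast(1 - tf.cast(attention_mask, tf.int32), tf.bool)

# invert_mask_sketch(tf.Tensor) would fail: tf.Tensor is the class, not a tensor value.
decoder_attn_mask = tf.constant([[1, 1, 1, 0, 0]])
print(invert_mask_sketch(decoder_attn_mask))  # [[False False False  True  True]]
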
@@ -903,6 +903,7 @@ def call(
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=True,
+                training=training,
             )
         decoder_outputs = self.decoder(
             decoder_input_ids,
@@ -915,6 +916,7 @@
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            training=training,
         )
         if not return_dict:
             # Attention and hidden_states will be [] or None if they aren't needed
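
The remaining two additions thread the training flag from the model's call into the encoder and decoder sub-layers, so that layers such as dropout see the correct mode when those sub-layers are invoked directly. A minimal sketch of that pattern, using a hypothetical TinySeq2Seq layer rather than the actual BART classes:

import tensorflow as tf


class TinySeq2Seq(tf.keras.layers.Layer):
    # Hypothetical two-stage layer (not the actual BART classes) illustrating the
    # pattern of forwarding `training` into manually invoked sub-layers.
    def __init__(self, hidden=16, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.encoder = tf.keras.layers.Dense(hidden, activation="relu")
        self.dropout = tf.keras.layers.Dropout(dropout)
        self.decoder = tf.keras.layers.Dense(hidden)

    def call(self, inputs, training=False):
        x = self.encoder(inputs)
        # Forwarding `training` explicitly keeps dropout's train/inference behavior
        # unambiguous instead of relying on Keras's implicit propagation rules.
        x = self.dropout(x, training=training)
        return self.decoder(x)


layer = TinySeq2Seq()
x = tf.random.normal((2, 8))
out_train = layer(x, training=True)    # dropout active
out_infer = layer(x, training=False)   # dropout disabled
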
