diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index af32b954005e65..a029828763e438 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -304,7 +304,7 @@ def forward(self, input_ids, attention_mask=None, output_attentions=False):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 attn = None
             else:
-                x, attn = encoder_layer(x, attention_mask)
+                x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)
 
             if output_attentions:
                 all_attentions.append(attn)
@@ -830,6 +830,7 @@ def forward(
             decoder_padding_mask, causal_mask = None, None
 
         assert decoder_input_ids is not None
+
         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions,
@@ -843,8 +844,10 @@ def forward(
             decoder_padding_mask,
             decoder_causal_mask=causal_mask,
             decoder_cached_states=decoder_cached_states,
+            output_attentions=output_attentions,
             use_cache=use_cache,
         )
+
         # Attention and hidden_states will be [] or None if they aren't needed
         decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
         assert isinstance(decoder_outputs[0], torch.Tensor)
diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py
index 1a8940b9c48e3c..b41cee48c38e21 100644
--- a/src/transformers/modeling_t5.py
+++ b/src/transformers/modeling_t5.py
@@ -539,6 +539,7 @@ def forward(
                 past_key_value_state=cross_attn_past_key_value_state,
                 query_length=query_length,
                 use_cache=use_cache,
+                output_attentions=output_attentions,
             )
             hidden_states = cross_attention_outputs[0]
             # Combine self attn and cross attn key value states
@@ -966,6 +967,7 @@ def forward(
             encoder_attention_mask=attention_mask,
             head_mask=head_mask,
             use_cache=use_cache,
+            output_attentions=output_attentions,
         )
 
         if use_cache is True:
@@ -1117,6 +1119,7 @@ def forward(
             encoder_attention_mask=attention_mask,
             head_mask=head_mask,
             use_cache=use_cache,
+            output_attentions=output_attentions,
         )
 
         # insert decoder past at right place
diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py
index 839388510bec93..f472e886ab839c 100644
--- a/src/transformers/modeling_xlnet.py
+++ b/src/transformers/modeling_xlnet.py
@@ -374,7 +374,14 @@ def forward(
             if target_mapping is not None:
                 q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
                 attn_vec_g = self.rel_attn_core(
-                    q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask
+                    q_head_g,
+                    k_head_h,
+                    v_head_h,
+                    k_head_r,
+                    seg_mat=seg_mat,
+                    attn_mask=attn_mask_g,
+                    head_mask=head_mask,
+                    output_attentions=output_attentions,
                 )
 
                 if output_attentions:
diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py
index 99696ca5323db7..1e533939169fcc 100644
--- a/tests/test_modeling_xlnet.py
+++ b/tests/test_modeling_xlnet.py
@@ -238,7 +238,7 @@ def create_and_check_xlnet_base_model_with_att_output(
         model.to(torch_device)
         model.eval()
 
-        _, _, attentions = model(input_ids_1, target_mapping=target_mapping)
+        _, _, attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)
 
         self.parent.assertEqual(len(attentions), config.n_layer)
         self.parent.assertIsInstance(attentions[0], tuple)
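
For context, a minimal usage sketch of the call pattern these changes enable: `output_attentions` becomes a per-forward-call argument rather than a construction-time config flag, so attention weights are only computed and returned when a caller asks for them. The checkpoint name and input string below are illustrative, and the tuple indexing assumes the tuple-style model outputs of this library version.

```python
import torch
from transformers import XLNetModel, XLNetTokenizer

# Illustrative checkpoint; any XLNet checkpoint behaves the same way here.
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetModel.from_pretrained("xlnet-base-cased")
model.eval()

input_ids = tokenizer.encode("Hello, world!", return_tensors="pt")
with torch.no_grad():
    # Without output_attentions=True, the per-layer attention weights are
    # neither computed nor included in the output tuple.
    outputs = model(input_ids, output_attentions=True)

# Attentions are appended last when only attentions are requested:
# one entry per layer, matching the assertion in the updated test.
attentions = outputs[-1]
assert len(attentions) == model.config.n_layer
```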