
Merge pull request huggingface#1 from patrickvonplaten/fix_tests_in_output_attentions

fix pytorch tests
patrickvonplaten authored Jun 4, 2020
2 parents c13f4b4 + b825601 commit f51162a
Showing 4 changed files with 16 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/transformers/modeling_bart.py
@@ -304,7 +304,7 @@ def forward(self, input_ids, attention_mask=None, output_attentions=False):
            if self.training and (dropout_probability < self.layerdrop): # skip the layer
                attn = None
            else:
                x, attn = encoder_layer(x, attention_mask)
                x, attn = encoder_layer(x, attention_mask, output_attentions=output_attentions)

            if output_attentions:
                all_attentions.append(attn)
@@ -830,6 +830,7 @@ def forward(
            decoder_padding_mask, causal_mask = None, None

        assert decoder_input_ids is not None

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions,
@@ -843,8 +844,10 @@
            decoder_padding_mask,
            decoder_causal_mask=causal_mask,
            decoder_cached_states=decoder_cached_states,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        # Attention and hidden_states will be [] or None if they aren't needed
        decoder_outputs: Tuple = _filter_out_falsey_values(decoder_outputs)
        assert isinstance(decoder_outputs[0], torch.Tensor)
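For context, a minimal sketch of what the BART change above enables end to end: attentions are requested per call rather than through the config, and _filter_out_falsey_values keeps them in the returned tuple only when they were asked for. The checkpoint name and the tuple-length comparison are illustrative assumptions, not something this commit specifies.

import torch
from transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartModel.from_pretrained("facebook/bart-large")
model.eval()

input_ids = tokenizer.encode("Hello world", return_tensors="pt")
with torch.no_grad():
    with_attn = model(input_ids, output_attentions=True)
    without_attn = model(input_ids)

# With output_attentions=True the per-layer attention tensors survive
# _filter_out_falsey_values, so the first call returns a longer tuple.
assert len(with_attn) > len(without_attn)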
3 changes: 3 additions & 0 deletions src/transformers/modeling_t5.py
@@ -539,6 +539,7 @@ def forward(
                past_key_value_state=cross_attn_past_key_value_state,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]
            # Combine self attn and cross attn key value states
@@ -966,6 +967,7 @@ def forward(
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        if use_cache is True:
@@ -1117,6 +1119,7 @@ def forward(
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        # insert decoder past at right place
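Likewise, a small sketch of the T5 path touched above, with output_attentions now forwarded into the decoder stack as well as the encoder. Checkpoint name, example text, and output ordering are illustrative assumptions.

import torch
from transformers import T5Model, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")
model.eval()

input_ids = tokenizer.encode("studies have shown that owning a dog is good for you", return_tensors="pt")
decoder_input_ids = tokenizer.encode("studies show", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, output_attentions=True)
# The trailing entries of the output tuple hold one attention tensor per
# decoder and encoder layer (exact ordering assumed here).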
9 changes: 8 additions & 1 deletion src/transformers/modeling_xlnet.py
@@ -374,7 +374,14 @@ def forward(
            if target_mapping is not None:
                q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
                attn_vec_g = self.rel_attn_core(
                    q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask
                    q_head_g,
                    k_head_h,
                    v_head_h,
                    k_head_r,
                    seg_mat=seg_mat,
                    attn_mask=attn_mask_g,
                    head_mask=head_mask,
                    output_attentions=output_attentions,
                )

                if output_attentions:
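Finally, a sketch of the two-stream XLNet case exercised by the test change below: with a target_mapping and output_attentions=True, each layer is expected to yield a pair of attention tensors, one for the content stream and one for the query stream. Checkpoint name, token ids, and shapes are illustrative assumptions.

import torch
from transformers import XLNetModel

model = XLNetModel.from_pretrained("xlnet-base-cased")
model.eval()

input_ids = torch.tensor([[17, 23, 42, 9]])
target_mapping = torch.zeros(1, 1, input_ids.shape[1])
target_mapping[0, 0, -1] = 1.0  # predict only the last position

with torch.no_grad():
    outputs = model(input_ids, target_mapping=target_mapping, output_attentions=True)
attentions = outputs[-1]
# One entry per layer; each entry is itself a (content stream, query stream)
# pair, which is what assertIsInstance(attentions[0], tuple) checks below.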
2 changes: 1 addition & 1 deletion tests/test_modeling_xlnet.py
@@ -238,7 +238,7 @@ def create_and_check_xlnet_base_model_with_att_output(
        model.to(torch_device)
        model.eval()

        _, _, attentions = model(input_ids_1, target_mapping=target_mapping)
        _, _, attentions = model(input_ids_1, target_mapping=target_mapping, output_attentions=True)

        self.parent.assertEqual(len(attentions), config.n_layer)
        self.parent.assertIsInstance(attentions[0], tuple)
