Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cross-attention head mask for Torch encoder-decoder models #10605

Merged
merged 19 commits into from
Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 50 additions & 34 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def forward(
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.
`(encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
Expand Down Expand Up @@ -368,7 +368,7 @@ def forward(
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
encoder_layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
Expand All @@ -382,9 +382,9 @@ def forward(
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.
encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of
size `(config.encoder_attention_heads,)`.
`(encoder_attention_heads,)`.
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(decoder_attention_heads,)`.
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
Expand Down Expand Up @@ -419,7 +419,7 @@ def forward(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=encoder_layer_head_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
Expand Down Expand Up @@ -598,18 +598,25 @@ def __init_subclass__(self):
If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and
modify to your needs. See diagram 1 in `the paper <https://arxiv.org/abs/1910.13461>`__ for more
information on the default strategy.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
- 0 indicates the head is **masked**.

decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0,
1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
Expand Down Expand Up @@ -710,11 +717,11 @@ def forward(
- 0 for tokens that are **masked**.

`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
- 0 indicates the head is **masked**.

inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
Expand Down Expand Up @@ -875,7 +882,7 @@ def forward(
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
Expand Down Expand Up @@ -912,18 +919,18 @@ def forward(
- 0 for tokens that are **masked**.

`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
- 0 indicates the head is **masked**.

encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
cross-attention on hidden heads. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
- 0 indicates the head is **masked**.

past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
Expand Down Expand Up @@ -993,11 +1000,12 @@ def forward(
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None

# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
Expand Down Expand Up @@ -1031,7 +1039,7 @@ def custom_forward(*inputs):
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
encoder_head_mask[idx] if encoder_head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
Expand All @@ -1042,7 +1050,9 @@ def custom_forward(*inputs):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
Expand Down Expand Up @@ -1123,6 +1133,7 @@ def forward(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
Expand Down Expand Up @@ -1172,7 +1183,7 @@ def forward(
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
encoder_head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
Expand Down Expand Up @@ -1248,6 +1259,7 @@ def forward(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
Expand Down Expand Up @@ -1282,6 +1294,7 @@ def forward(
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
Expand Down Expand Up @@ -1386,6 +1399,7 @@ def forward(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
Expand Down Expand Up @@ -1416,6 +1430,7 @@ def forward(
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
Expand Down Expand Up @@ -1496,6 +1511,7 @@ def forward(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
start_positions=None,
end_positions=None,
Expand Down Expand Up @@ -1527,6 +1543,7 @@ def forward(
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
Expand Down Expand Up @@ -1633,7 +1650,7 @@ def forward(
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
Expand Down Expand Up @@ -1666,18 +1683,17 @@ def forward(
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
- 0 indicates the head is **masked**.

encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
- 0 indicates the head is **masked**.

past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
Expand Down Expand Up @@ -1734,7 +1750,7 @@ def forward(
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
Expand Down
Loading