Add head_mask, decoder_head_mask, cross_head_mask to ProphetNet #9964

Merged · 9 commits · Apr 25, 2021
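
For context, a minimal usage sketch of the three new arguments (not part of the diff below). The checkpoint name and the config attribute names (num_encoder_layers, num_encoder_attention_heads, num_decoder_layers, num_decoder_attention_heads) are assumptions based on ProphetNetConfig; adjust them to your setup.

import torch
from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")

inputs = tokenizer("ProphetNet predicts future n-grams.", return_tensors="pt")
decoder_inputs = tokenizer("It predicts", return_tensors="pt")

cfg = model.config
# 1.0 keeps a head, 0.0 masks it; each mask has shape (num_layers, num_heads)
head_mask = torch.ones(cfg.num_encoder_layers, cfg.num_encoder_attention_heads)
head_mask[0, 0] = 0.0  # silence head 0 of the first encoder layer
decoder_head_mask = torch.ones(cfg.num_decoder_layers, cfg.num_decoder_attention_heads)
cross_attn_head_mask = torch.ones(cfg.num_decoder_layers, cfg.num_decoder_attention_heads)

outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=decoder_inputs.input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
)

The three masks mirror the docstrings added in the diff: head_mask feeds the encoder self-attention, decoder_head_mask the decoder's main and n-gram self-attention, and cross_attn_head_mask the decoder's cross-attention over the encoder states.
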
150 changes: 145 additions & 5 deletions src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -104,6 +104,24 @@
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -146,6 +164,12 @@
- 0 for tokens that are **masked**.

`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
@@ -633,6 +657,7 @@ def forward(
hidden_states,
key_value_states: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
layer_head_mask: Optional[Tensor] = None,
past_key_value: Optional[Tuple[Tensor]] = None,
output_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
@@ -708,6 +733,19 @@ def forward(
attn_weights_reshaped = None

attn_weights = F.softmax(attn_weights, dim=-1)

if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
batch_size, self.num_attn_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)

# apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped if attn_weights_reshaped is not None else None

attn_probs = F.dropout(
attn_weights,
p=self.attention_dropout,
@@ -797,6 +835,7 @@ def forward(
hidden_states,
past_key_value: Optional[Tuple[Tensor]] = None,
attention_mask=None,
layer_head_mask=None,
extended_predict_attention_mask=None,
main_relative_position_buckets=None,
predict_relative_position_buckets=None,
@@ -876,6 +915,15 @@ def forward(
onnx_trace=self.onnx_trace,
).type_as(main_attn_weights)

if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
batch_size, self.num_attn_heads, -1, sequence_length
)
main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)

main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
# project to attn_output
main_attn_output = torch.bmm(main_attn_probs, main_value_states)
@@ -929,6 +977,18 @@ def forward(
dim=-1,
onnx_trace=self.onnx_trace,
).type_as(predict_attn_weights)

if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
)
predict_attn_probs = predict_attn_probs.view(
self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
)

predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training)
# project to attention output
# [ngram, B*head, T, c]
@@ -1063,11 +1123,18 @@ def __init__(self, config: ProphetNetConfig):
self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

def forward(self, hidden_states, attention_mask, output_attentions: bool = False):
def forward(
self,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions: bool = False,
):
# 1st residual block
attention_output, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
@@ -1110,6 +1177,8 @@ def forward(
attention_mask=None,
encoder_hidden_states=None,
encoder_attn_mask=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
extended_predict_attention_mask=None,
main_relative_position_buckets=None,
predict_relative_position_buckets=None,
@@ -1125,6 +1194,7 @@ def forward(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
@@ -1141,6 +1211,7 @@ def forward(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attn_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
@@ -1202,6 +1273,7 @@ def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
@@ -1254,7 +1326,12 @@ def forward(
encoder_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

for encoder_layer in self.layers:
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_hidden_states = encoder_hidden_states + (hidden_states,)

@@ -1270,10 +1347,14 @@ def custom_forward(*inputs):
create_custom_forward(encoder_layer),
hidden_states,
extended_attention_mask,
(head_mask[idx] if head_mask is not None else None),
)
else:
layer_outputs = encoder_layer(
hidden_states, attention_mask=extended_attention_mask, output_attentions=output_attentions
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
@@ -1338,6 +1419,8 @@ def forward(
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
@@ -1352,6 +1435,12 @@ def forward(
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.

@@ -1460,6 +1549,12 @@ def forward(
all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
present_key_values = () if use_cache else None

# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
# grad cannot be kept because tensor is sliced
@@ -1491,6 +1586,8 @@ def custom_forward(*inputs):
extended_attention_mask,
encoder_hidden_states,
extended_encoder_attention_mask,
(head_mask[idx] if head_mask is not None else None),
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
extended_predict_attention_mask,
main_relative_position_buckets,
predict_relative_position_buckets,
@@ -1503,6 +1600,10 @@ def custom_forward(*inputs):
attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
@@ -1678,6 +1779,9 @@ def forward(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Tuple] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1716,6 +1820,7 @@ def forward(
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -1728,6 +1833,8 @@ def forward(
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
@@ -1785,6 +1892,9 @@ def forward(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
@@ -1828,6 +1938,9 @@ def forward(
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1902,7 +2015,14 @@ def _compute_loss(self, logits, labels, ignore_index=-100):
return loss

def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."

@@ -1915,6 +2035,7 @@ def prepare_inputs_for_generation(
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache,
}

@@ -1985,6 +2106,8 @@ def forward(
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
@@ -2000,6 +2123,12 @@ def forward(
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.

@@ -2060,6 +2189,8 @@ def forward(
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
@@ -2123,7 +2254,15 @@ def _compute_loss(self, logits, labels, ignore_index=-100):

return loss

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
**kwargs,
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
@@ -2134,6 +2273,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"head_mask": head_mask,
"past_key_values": past,
"use_cache": use_cache,
}
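
The per-layer masking added above reduces to one broadcast: the (num_heads,) mask for a layer is viewed as (1, num_heads, 1, 1) and multiplied into the post-softmax attention weights. A standalone sketch with illustrative tensor names (not the model's internals):

import torch

batch_size, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.rand(batch_size * num_heads, tgt_len, src_len).softmax(dim=-1)

layer_head_mask = torch.ones(num_heads)
layer_head_mask[1] = 0.0  # drop the second head of this layer

# broadcast over the batch and sequence dimensions, zeroing the masked head's weights
masked = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
    batch_size, num_heads, tgt_len, src_len
)
masked = masked.view(batch_size * num_heads, tgt_len, src_len)

assert torch.all(masked.view(batch_size, num_heads, tgt_len, src_len)[:, 1] == 0)

The n-gram stream in ProphetNetNgramSelfAttention applies the same idea with one extra leading dimension, viewing the mask as (1, 1, num_heads, 1, 1) so the broadcast also covers the ngram axis.
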
3 changes: 0 additions & 3 deletions tests/test_modeling_prophetnet.py
@@ -891,7 +891,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = True

def setUp(self):
@@ -1097,7 +1096,6 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = False

def setUp(self):
@@ -1126,7 +1124,6 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = False

def setUp(self):
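
With test_headmasking = False removed, the common head-masking tests in ModelTesterMixin now run for ProphetNet. A rough sketch of the property those tests exercise (checkpoint and attribute names follow the example at the top of this page; the actual common-test internals differ):

import torch
from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")
model.eval()

enc = tokenizer("ProphetNet predicts future n-grams.", return_tensors="pt")
cfg = model.config

head_mask = torch.ones(cfg.num_encoder_layers, cfg.num_encoder_attention_heads)
head_mask[0, 0] = 0.0  # mask head 0 of encoder layer 0

with torch.no_grad():
    out = model(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        decoder_input_ids=enc.input_ids,
        head_mask=head_mask,
        output_attentions=True,
    )

# the attention weights returned for the masked head should be exactly zero
assert torch.all(out.encoder_attentions[0][:, 0] == 0)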