Add head_mask, decoder_head_mask, cross_head_mask to ProphetNet (huggingface#9964)

* Add head_mask & decoder_head_mask + some corrections

* Fix head masking for N-grams

* Enable test_headmasking for encoder and decoder

* Fix a typo in modeling_prophetnet.py

* Enable test_headmasking for ProphetNetStandaloneDecoderModelTest
and ProphetNetStandaloneEncoderModelTest in test_modeling_prophetnet.py

* make style

* Fix cross_head_mask

* Fix attention head mask naming

* `cross_head_mask` -> `cross_attn_head_mask`

* `cross_layer_head_mask` -> `cross_attn_layer_head_mask`

* Still need to merge huggingface#10605 to master to pass the tests
stancld authored and PhilipMay committed Apr 26, 2021
1 parent 05e9d86 commit d967cdf
Showing 2 changed files with 145 additions and 8 deletions.
150 changes: 145 additions & 5 deletions src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -104,6 +104,24 @@
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -146,6 +164,12 @@
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
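As a usage sketch (not part of this diff): the masks documented above are plain 0/1 tensors with one row per layer and one column per head. The snippet below assumes the public `microsoft/prophetnet-large-uncased` checkpoint and the `ProphetNetConfig` attribute names `num_encoder_layers`, `num_encoder_attention_heads`, `num_decoder_layers`, and `num_decoder_attention_heads`:

```python
# Usage sketch only; assumes torch and transformers are installed.
import torch
from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer

tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")

inputs = tokenizer("ProphetNet predicts future n-grams.", return_tensors="pt")
decoder_inputs = tokenizer("ProphetNet", return_tensors="pt")

cfg = model.config
# 1.0 keeps a head, 0.0 masks it, per the docstrings above.
head_mask = torch.ones(cfg.num_encoder_layers, cfg.num_encoder_attention_heads)
decoder_head_mask = torch.ones(cfg.num_decoder_layers, cfg.num_decoder_attention_heads)
cross_attn_head_mask = torch.ones(cfg.num_decoder_layers, cfg.num_decoder_attention_heads)
head_mask[0, 0] = 0.0  # silence head 0 of the first encoder layer

outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=decoder_inputs.input_ids,
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
    output_attentions=True,  # also return per-head attention weights
)
```

Requesting `output_attentions=True` returns the per-head attention weights, which makes it easy to confirm that the masked head's rows come back as zeros.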
@@ -633,6 +657,7 @@ def forward(
hidden_states,
key_value_states: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
layer_head_mask: Optional[Tensor] = None,
past_key_value: Optional[Tuple[Tensor]] = None,
output_attentions: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
@@ -708,6 +733,19 @@ def forward(
attn_weights_reshaped = None

attn_weights = F.softmax(attn_weights, dim=-1)

if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
batch_size, self.num_attn_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)

# apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped

attn_probs = F.dropout(
attn_weights,
p=self.attention_dropout,
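The hunk above applies `layer_head_mask` by broadcasting: the per-head scalars are viewed as `(1, num_heads, 1, 1)` and multiplied into the attention weights viewed as `(batch_size, num_heads, tgt_len, src_len)`, then flattened back to `(batch_size * num_heads, tgt_len, src_len)`. A standalone shape check of that trick with made-up sizes (nothing here comes from the model):

```python
import torch

batch_size, num_heads, tgt_len, src_len = 2, 4, 5, 7
attn_weights = torch.rand(batch_size * num_heads, tgt_len, src_len)

# One scalar per head: 1.0 keeps the head, 0.0 silences it (here: head 1).
layer_head_mask = torch.ones(num_heads)
layer_head_mask[1] = 0.0

masked = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
    batch_size, num_heads, tgt_len, src_len
)
masked = masked.view(batch_size * num_heads, tgt_len, src_len)

# Head 1 is zeroed for every batch element; the other heads are untouched.
per_head = masked.view(batch_size, num_heads, tgt_len, src_len)
original = attn_weights.view(batch_size, num_heads, tgt_len, src_len)
assert torch.all(per_head[:, 1] == 0)
assert torch.equal(per_head[:, 0], original[:, 0])
```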
@@ -797,6 +835,7 @@ def forward(
hidden_states,
past_key_value: Optional[Tuple[Tensor]] = None,
attention_mask=None,
layer_head_mask=None,
extended_predict_attention_mask=None,
main_relative_position_buckets=None,
predict_relative_position_buckets=None,
@@ -876,6 +915,15 @@ def forward(
onnx_trace=self.onnx_trace,
).type_as(main_attn_weights)

if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
batch_size, self.num_attn_heads, -1, sequence_length
)
main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)

main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
# project to attn_output
main_attn_output = torch.bmm(main_attn_probs, main_value_states)
@@ -929,6 +977,18 @@ def forward(
dim=-1,
onnx_trace=self.onnx_trace,
).type_as(predict_attn_weights)

if layer_head_mask is not None:
assert layer_head_mask.size() == (
self.num_attn_heads,
), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
)
predict_attn_probs = predict_attn_probs.view(
self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
)

predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training)
# project to attention output
# [ngram, B*head, T, c]
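The predict (n-gram) stream repeats the same trick with an extra leading n-gram dimension, viewing the mask as `(1, 1, num_heads, 1, 1)` against probabilities of shape `(ngram, batch_size, num_heads, sequence_length, 2 * sequence_length)`. A shape-only illustration, again with arbitrary sizes:

```python
import torch

ngram, batch_size, num_heads, seq_len = 2, 2, 4, 5
predict_attn_probs = torch.rand(ngram, batch_size * num_heads, seq_len, 2 * seq_len)

layer_head_mask = torch.ones(num_heads)
layer_head_mask[3] = 0.0  # silence the last head

masked = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
    ngram, batch_size, num_heads, seq_len, 2 * seq_len
)
masked = masked.view(ngram, batch_size * num_heads, seq_len, 2 * seq_len)

# Head 3 is zero in every n-gram stream and batch element.
assert torch.all(masked.view(ngram, batch_size, num_heads, seq_len, -1)[:, :, 3] == 0)
```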
@@ -1063,11 +1123,18 @@ def __init__(self, config: ProphetNetConfig):
self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

def forward(self, hidden_states, attention_mask, output_attentions: bool = False):
def forward(
self,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions: bool = False,
):
# 1st residual block
attention_output, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
@@ -1110,6 +1177,8 @@ def forward(
attention_mask=None,
encoder_hidden_states=None,
encoder_attn_mask=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
extended_predict_attention_mask=None,
main_relative_position_buckets=None,
predict_relative_position_buckets=None,
@@ -1125,6 +1194,7 @@
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
@@ -1141,6 +1211,7 @@
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attn_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
@@ -1202,6 +1273,7 @@ def forward(
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
@@ -1254,7 +1326,12 @@ def forward(
encoder_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None

for encoder_layer in self.layers:
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_hidden_states = encoder_hidden_states + (hidden_states,)

@@ -1270,10 +1347,14 @@ def custom_forward(*inputs):
create_custom_forward(encoder_layer),
hidden_states,
extended_attention_mask,
(head_mask[idx] if head_mask is not None else None),
)
else:
layer_outputs = encoder_layer(
hidden_states, attention_mask=extended_attention_mask, output_attentions=output_attentions
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]
@@ -1338,6 +1419,8 @@ def forward(
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
@@ -1352,6 +1435,12 @@ def forward(
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
@@ -1460,6 +1549,12 @@ def forward(
all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
present_key_values = () if use_cache else None

# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (
len(self.layers)
), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
# grad cannot be kept because tensor is sliced
@@ -1491,6 +1586,8 @@ def custom_forward(*inputs):
extended_attention_mask,
encoder_hidden_states,
extended_encoder_attention_mask,
(head_mask[idx] if head_mask is not None else None),
(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
extended_predict_attention_mask,
main_relative_position_buckets,
predict_relative_position_buckets,
@@ -1503,6 +1600,10 @@ def custom_forward(*inputs):
attention_mask=extended_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attn_mask=extended_encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
extended_predict_attention_mask=extended_predict_attention_mask,
main_relative_position_buckets=main_relative_position_buckets,
predict_relative_position_buckets=predict_relative_position_buckets,
@@ -1678,6 +1779,9 @@ def forward(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs: Optional[Tuple] = None,
past_key_values=None,
inputs_embeds=None,
@@ -1716,6 +1820,7 @@ def forward(
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
@@ -1728,6 +1833,8 @@ def forward(
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
@@ -1785,6 +1892,9 @@ def forward(
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
@@ -1828,6 +1938,9 @@ def forward(
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1902,7 +2015,14 @@ def _compute_loss(self, logits, labels, ignore_index=-100):
return loss

def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."

@@ -1915,6 +2035,7 @@ def prepare_inputs_for_generation(
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"use_cache": use_cache,
}

@@ -1985,6 +2106,8 @@ def forward(
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
@@ -2000,6 +2123,12 @@ def forward(
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
@@ -2060,6 +2189,8 @@ def forward(
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
@@ -2123,7 +2254,15 @@ def _compute_loss(self, logits, labels, ignore_index=-100):

return loss

def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
attention_mask=None,
head_mask=None,
use_cache=None,
**kwargs,
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
@@ -2134,6 +2273,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=Non
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"head_mask": head_mask,
"past_key_values": past,
"use_cache": use_cache,
}
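The new arguments also reach the standalone decoder path (`ProphetNetDecoder` / `ProphetNetForCausalLM`). A rough sketch with a tiny randomly initialized config; the sizes are arbitrary and the call relies only on the signature added in this commit:

```python
import torch
from transformers import ProphetNetConfig, ProphetNetForCausalLM

# Arbitrary tiny sizes so the forward pass runs quickly; not taken from this repo's tests.
config = ProphetNetConfig(
    vocab_size=128,
    hidden_size=16,
    num_decoder_layers=2,
    num_decoder_attention_heads=4,
    decoder_ffn_dim=32,
)
model = ProphetNetForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 6))
head_mask = torch.ones(config.num_decoder_layers, config.num_decoder_attention_heads)
head_mask[1, 2] = 0.0  # silence head 2 of the second decoder layer

outputs = model(input_ids=input_ids, head_mask=head_mask, output_attentions=True)
```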
3 changes: 0 additions & 3 deletions tests/test_modeling_prophetnet.py
@@ -891,7 +891,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = True

def setUp(self):
@@ -1097,7 +1096,6 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMix
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = False

def setUp(self):
@@ -1126,7 +1124,6 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_headmasking = False
is_encoder_decoder = False

def setUp(self):
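Dropping `test_headmasking = False` opts these classes into the shared head-masking tests from `ModelTesterMixin`. The property those tests lean on, that a zeroed head contributes zero attention weight, can be sketched directly against a tiny randomly initialized model (illustrative only, not the actual test code):

```python
import torch
from transformers import ProphetNetConfig, ProphetNetModel

# Tiny random-weight model so the check runs quickly; all sizes here are arbitrary.
config = ProphetNetConfig(
    vocab_size=128,
    hidden_size=16,
    num_encoder_layers=2,
    num_encoder_attention_heads=4,
    num_decoder_layers=2,
    num_decoder_attention_heads=4,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
)
model = ProphetNetModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))
decoder_input_ids = torch.randint(0, config.vocab_size, (1, 5))

head_mask = torch.ones(config.num_encoder_layers, config.num_encoder_attention_heads)
head_mask[0, 0] = 0.0  # mask head 0 of the first encoder layer

outputs = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    output_attentions=True,
    return_dict=True,
)
# The attention weights returned for the masked head should be exactly zero.
assert torch.all(outputs.encoder_attentions[0][:, 0] == 0)
```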
