diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index 7b7a3cc81f09da..ab353b3343601e 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -104,6 +104,24 @@
         decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
             also be used by default.
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+            Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
             :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
@@ -146,6 +164,12 @@
             - 0 for tokens that are **masked**.

             `What are attention masks? <../glossary.html#attention-mask>`__
+        head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
         output_attentions (:obj:`bool`, `optional`):
             Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
             tensors for more detail.
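The three new arguments share one convention: a float tensor of shape `(num_layers, num_heads)` where 1 keeps a head and 0 silences it. A minimal usage sketch of the encoder-side mask (not part of the diff; the checkpoint name is illustrative, and it assumes the checkpoint config defines `decoder_start_token_id`):

```python
import torch
from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer

# illustrative checkpoint; any ProphetNet checkpoint should behave the same way
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")

inputs = tokenizer("ProphetNet predicts future n-grams.", return_tensors="pt")
# single decoder start token, taken from the config (assumed to be set for this checkpoint)
decoder_input_ids = torch.full((1, 1), model.config.decoder_start_token_id, dtype=torch.long)

# keep every head except head 0 of each encoder layer
head_mask = torch.ones(model.config.num_encoder_layers, model.config.num_encoder_attention_heads)
head_mask[:, 0] = 0.0

outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=decoder_input_ids,
    head_mask=head_mask,
    output_attentions=True,
)
```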
@@ -633,6 +657,7 @@ def forward(
         hidden_states,
         key_value_states: Optional[Tensor] = None,
         attention_mask: Optional[Tensor] = None,
+        layer_head_mask: Optional[Tensor] = None,
         past_key_value: Optional[Tuple[Tensor]] = None,
         output_attentions: bool = False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -708,6 +733,20 @@ def forward(
             attn_weights_reshaped = None

         attn_weights = F.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_attn_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
+                batch_size, self.num_attn_heads, tgt_len, src_len
+            )
+            attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len)
+
+            # apply head_mask also on attn_weights_reshaped, which is returned when output_attentions=True
+            if attn_weights_reshaped is not None:
+                attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
+
         attn_probs = F.dropout(
             attn_weights,
             p=self.attention_dropout,
@@ -797,6 +835,7 @@ def forward(
         hidden_states,
         past_key_value: Optional[Tuple[Tensor]] = None,
         attention_mask=None,
+        layer_head_mask=None,
         extended_predict_attention_mask=None,
         main_relative_position_buckets=None,
         predict_relative_position_buckets=None,
@@ -876,6 +915,15 @@ def forward(
             onnx_trace=self.onnx_trace,
         ).type_as(main_attn_weights)

+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_attn_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+            main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
+                batch_size, self.num_attn_heads, -1, sequence_length
+            )
+            main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length)
+
         main_attn_probs = F.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
         # project to attn_output
         main_attn_output = torch.bmm(main_attn_probs, main_value_states)
@@ -929,6 +977,18 @@ def forward(
             dim=-1,
             onnx_trace=self.onnx_trace,
         ).type_as(predict_attn_weights)
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (
+                self.num_attn_heads,
+            ), f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is {layer_head_mask.size()}"
+            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view(
+                self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length
+            )
+            predict_attn_probs = predict_attn_probs.view(
+                self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length
+            )
+
         predict_attn_probs = F.dropout(predict_attn_probs, p=self.attention_dropout, training=self.training)
         # project to attention output
         # [ngram, B*head, T, c]
@@ -1063,11 +1123,18 @@ def __init__(self, config: ProphetNetConfig):
         self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
         self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

-    def forward(self, hidden_states, attention_mask, output_attentions: bool = False):
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        layer_head_mask,
+        output_attentions: bool = False,
+    ):
         # 1st residual block
         attention_output, attn_weights, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             output_attentions=output_attentions,
         )
         hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
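The attention modules patched above all apply the mask the same way: weights are stored as `(batch * num_heads, tgt_len, src_len)`, so the `(num_heads,)` mask is broadcast over a temporary 4-D view and the result is flattened back. A standalone sketch of that pattern with toy shapes (not part of the diff):

```python
import torch

# illustrative shapes; in the model these come from the config and the inputs
batch_size, num_attn_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.softmax(torch.randn(batch_size * num_attn_heads, tgt_len, src_len), dim=-1)
layer_head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # zero out head 1

# broadcast the (num_heads,) mask over a temporary 4-D view, then flatten back
masked = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
masked = masked.view(batch_size * num_attn_heads, tgt_len, src_len)

# head 1 of every batch element is now all zeros
assert masked.view(batch_size, num_attn_heads, tgt_len, src_len)[:, 1].abs().sum() == 0
```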
@@ -1110,6 +1177,8 @@ def forward(
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attn_mask=None,
+        layer_head_mask=None,
+        cross_attn_layer_head_mask=None,
         extended_predict_attention_mask=None,
         main_relative_position_buckets=None,
         predict_relative_position_buckets=None,
@@ -1125,6 +1194,7 @@ def forward(
             hidden_states=hidden_states,
             past_key_value=self_attn_past_key_value,
             attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
             extended_predict_attention_mask=extended_predict_attention_mask,
             main_relative_position_buckets=main_relative_position_buckets,
             predict_relative_position_buckets=predict_relative_position_buckets,
@@ -1141,6 +1211,7 @@ def forward(
             hidden_states=hidden_states,
             key_value_states=encoder_hidden_states,
             attention_mask=encoder_attn_mask,
+            layer_head_mask=cross_attn_layer_head_mask,
             past_key_value=cross_attn_past_key_value,
             output_attentions=output_attentions,
         )
@@ -1202,6 +1273,7 @@ def forward(
         self,
         input_ids=None,
         attention_mask=None,
+        head_mask=None,
         inputs_embeds=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -1254,7 +1326,12 @@ def forward(
         encoder_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None

-        for encoder_layer in self.layers:
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            assert head_mask.size()[0] == (
+                len(self.layers)
+            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
+        for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_hidden_states = encoder_hidden_states + (hidden_states,)

@@ -1270,10 +1347,14 @@ def custom_forward(*inputs):
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     extended_attention_mask,
+                    (head_mask[idx] if head_mask is not None else None),
                 )
             else:
                 layer_outputs = encoder_layer(
-                    hidden_states, attention_mask=extended_attention_mask, output_attentions=output_attentions
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    output_attentions=output_attentions,
                 )

             hidden_states = layer_outputs[0]
@@ -1338,6 +1419,8 @@ def forward(
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
         use_cache=None,
@@ -1352,6 +1435,12 @@ def forward(
             encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is
                 used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
             past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
                 decoding.
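The encoder loop above validates the stacked `(num_layers, num_heads)` mask once, then hands each layer its own `(num_heads,)` row, or `None` when no mask was given. A toy sketch of that per-layer dispatch (not part of the diff; `run_layer` is a hypothetical stand-in for `ProphetNetEncoderLayer.forward`):

```python
import torch

num_layers, num_heads = 3, 4
head_mask = torch.ones(num_layers, num_heads)
head_mask[1, 2] = 0.0  # prune head 2 of layer 1 only

# validated once up front, mirroring the assert in the diff
assert head_mask.size()[0] == num_layers, "one mask row per layer"

def run_layer(idx, layer_head_mask):
    # placeholder for ProphetNetEncoderLayer.forward(..., layer_head_mask=...)
    print(f"layer {idx}: {layer_head_mask.tolist()}")

for idx in range(num_layers):
    # each layer receives only its own row (or None when no mask is given)
    run_layer(idx, head_mask[idx] if head_mask is not None else None)
```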
@@ -1460,6 +1549,12 @@ def forward(
         all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
         present_key_values = () if use_cache else None

+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                assert attn_mask.size()[0] == (
+                    len(self.layers)
+                ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {attn_mask.size()[0]}."
         for idx, decoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 # grad cannot be kept because tensor is sliced
@@ -1491,6 +1586,8 @@ def custom_forward(*inputs):
                     extended_attention_mask,
                     encoder_hidden_states,
                     extended_encoder_attention_mask,
+                    (head_mask[idx] if head_mask is not None else None),
+                    (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                     extended_predict_attention_mask,
                     main_relative_position_buckets,
                     predict_relative_position_buckets,
@@ -1503,6 +1600,10 @@ def custom_forward(*inputs):
                     attention_mask=extended_attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attn_mask=extended_encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
                     extended_predict_attention_mask=extended_predict_attention_mask,
                     main_relative_position_buckets=main_relative_position_buckets,
                     predict_relative_position_buckets=predict_relative_position_buckets,
@@ -1678,6 +1779,9 @@ def forward(
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         encoder_outputs: Optional[Tuple] = None,
         past_key_values=None,
         inputs_embeds=None,
@@ -1716,6 +1820,7 @@ def forward(
             encoder_outputs = self.encoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
+                head_mask=head_mask,
                 inputs_embeds=inputs_embeds,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
@@ -1728,6 +1833,8 @@ def forward(
             attention_mask=decoder_attention_mask,
             encoder_hidden_states=encoder_outputs[0],
             encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
@@ -1785,6 +1892,9 @@ def forward(
         attention_mask=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         encoder_outputs=None,
         past_key_values=None,
         inputs_embeds=None,
@@ -1828,6 +1938,9 @@ def forward(
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
             encoder_outputs=encoder_outputs,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
@@ -1902,7 +2015,14 @@ def _compute_loss(self, logits, labels, ignore_index=-100):
         return loss

     def prepare_inputs_for_generation(
-        self, decoder_input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self,
+        decoder_input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
     ):
         assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
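A sketch exercising the decoder-side masks wired through above, reusing `model`, `inputs`, and `decoder_input_ids` from the first sketch (config attribute names are assumptions consistent with `ProphetNetConfig`). A direct forward call is shown because this diff's `prepare_inputs_for_generation` only forwards the encoder-side `head_mask`:

```python
decoder_head_mask = torch.ones(model.config.num_decoder_layers, model.config.num_decoder_attention_heads)
cross_attn_head_mask = torch.ones_like(decoder_head_mask)
decoder_head_mask[0, :2] = 0.0  # silence two self-attention heads in decoder layer 0

outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_head_mask=decoder_head_mask,
    cross_attn_head_mask=cross_attn_head_mask,
)
```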
@@ -1915,6 +2035,7 @@ def prepare_inputs_for_generation(
             "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
+            "head_mask": head_mask,
             "use_cache": use_cache,
         }

@@ -1985,6 +2106,8 @@ def forward(
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
         labels=None,
@@ -2000,6 +2123,12 @@ def forward(
             encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
                 Mask to avoid performing attention on the padding token indices of the encoder input. This mask is
                 used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in ``[0, 1]``:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
             past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
                 decoding.
@@ -2060,6 +2189,8 @@ def forward(
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
@@ -2123,7 +2254,15 @@ def _compute_loss(self, logits, labels, ignore_index=-100):
         return loss

-    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past=None,
+        attention_mask=None,
+        head_mask=None,
+        use_cache=None,
+        **kwargs,
+    ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
@@ -2134,6 +2273,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
         return {
             "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
             "attention_mask": attention_mask,
+            "head_mask": head_mask,
             "past_key_values": past,
             "use_cache": use_cache,
         }
diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py
index 7314f6f4147b0e..caeb8413130ad0 100644
--- a/tests/test_modeling_prophetnet.py
+++ b/tests/test_modeling_prophetnet.py
@@ -891,7 +891,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     test_pruning = False
     test_torchscript = False
     test_resize_embeddings = False
-    test_headmasking = False
     is_encoder_decoder = True

     def setUp(self):
@@ -1097,7 +1096,6 @@ class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     test_pruning = False
     test_torchscript = False
     test_resize_embeddings = False
-    test_headmasking = False
     is_encoder_decoder = False

     def setUp(self):
@@ -1126,7 +1124,6 @@ class ProphetNetStandaloneEncoderModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_torchscript = False
     test_resize_embeddings = False
-    test_headmasking = False
     is_encoder_decoder = False

     def setUp(self):
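Dropping `test_headmasking = False` re-enables the common head-masking tests from `ModelTesterMixin` for all three test classes. In spirit, those tests check that a fully masked head contributes nothing to the attention output; a simplified, standalone version of that property (not the actual test code):

```python
import torch

attn = torch.softmax(torch.randn(2, 4, 5, 5), dim=-1)   # (batch, heads, tgt, src)
layer_head_mask = torch.tensor([0.0, 1.0, 1.0, 1.0])     # head 0 fully masked
masked = layer_head_mask.view(1, -1, 1, 1) * attn

assert torch.equal(masked[:, 0], torch.zeros_like(masked[:, 0]))  # masked head is all zeros
assert torch.allclose(masked[:, 1:], attn[:, 1:])                 # other heads untouched
```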