support 3D attention mask in bert #32105

Merged: 5 commits, Sep 6, 2024

Changes from all commits
src/transformers/models/bert/modeling_bert.py: 8 changes (4 additions & 4 deletions)

@@ -908,7 +908,7 @@ class BertForPreTrainingOutput(ModelOutput):
[`PreTrainedTokenizer.__call__`] for details.

[What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ attention_mask (`torch.FloatTensor` of shape `({0})` or `(batch_size, sequence_length, target_length)`, *optional*):

Collaborator: let's maybe recommend 4d here!

Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,
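
For context, a minimal usage sketch of the 3D shape the updated docstring allows: one `(sequence_length, target_length)` mask per batch entry, passed as a single tensor. The tiny config values and the masked positions below are illustrative assumptions, not part of this PR.

```python
import torch
from transformers import BertConfig, BertModel

# Small randomly initialized model, used only to illustrate shapes.
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = BertModel(config)

batch_size, seq_len = 2, 8
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

# 3D mask: one (sequence_length, target_length) matrix per sample, values in {0, 1}.
attention_mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.long)
attention_mask[0, :, -2:] = 0  # e.g. mask out the last two key positions of the first sample

outputs = model(input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state.shape)  # torch.Size([2, 8, 32])
```
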
@@ -1023,7 +1023,7 @@ def forward(
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
- encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

@@ -1093,7 +1093,7 @@ def forward(
)

# Expand the attention mask
- if use_sdpa_attention_masks:
+ if use_sdpa_attention_masks and attention_mask.dim() == 2:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
if self.config.is_decoder:
@@ -1120,7 +1120,7 @@ def forward(
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

- if use_sdpa_attention_masks:
+ if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
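
To make the shape logic the new `dim() == 2` checks gate on easier to follow, here is a sketch in plain PyTorch (not the library's internal helpers): a 2D padding mask still needs to be broadcast over heads and query positions, whereas a 3D mask already carries the per-query dimension and is left to the generic extended-mask path.

```python
import torch

bsz, seq_len = 2, 5
mask_2d = torch.ones(bsz, seq_len)             # [bsz, seq_len]
mask_3d = torch.ones(bsz, seq_len, seq_len)    # [bsz, seq_len, tgt_len]

# 2D case: broadcast over heads and query positions -> [bsz, 1, 1, tgt_len]
expanded_2d = mask_2d[:, None, None, :]

# 3D case: the query dimension is already present -> [bsz, 1, seq_len, tgt_len]
expanded_3d = mask_3d[:, None, :, :]

print(expanded_2d.shape)  # torch.Size([2, 1, 1, 5])
print(expanded_3d.shape)  # torch.Size([2, 1, 5, 5])
```
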
tests/models/bert/test_modeling_bert.py: 38 changes (38 additions & 0 deletions)

@@ -498,6 +498,14 @@ def test_model_various_embeddings(self):
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

def test_model_3d_mask_shapes(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# manipulate input_mask
config_and_inputs = list(config_and_inputs)
batch_size, seq_length = config_and_inputs[3].shape
config_and_inputs[3] = random_attention_mask([batch_size, seq_length, seq_length])
self.model_tester.create_and_check_model(*config_and_inputs)

def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
@@ -530,6 +538,36 @@ def test_model_as_decoder_with_default_input_mask(self):
encoder_attention_mask,
)

def test_model_as_decoder_with_3d_input_mask(self):
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = self.model_tester.prepare_config_and_inputs_for_decoder()

batch_size, seq_length = input_mask.shape
input_mask = random_attention_mask([batch_size, seq_length, seq_length])
batch_size, seq_length = encoder_attention_mask.shape
encoder_attention_mask = random_attention_mask([batch_size, seq_length, seq_length])

self.model_tester.create_and_check_model_as_decoder(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
)

def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
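
Assuming a development install run from the repository root, one way to execute just the two new 3D-mask tests is pytest's `-k` name filter, for example via `pytest.main`:

```python
import pytest

# Select only the tests added in this PR by name.
pytest.main([
    "tests/models/bert/test_modeling_bert.py",
    "-k", "test_model_3d_mask_shapes or test_model_as_decoder_with_3d_input_mask",
])
```
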