Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support 3D attention mask in bert #32105

Merged
merged 5 commits into from
Sep 6, 2024

Conversation

gathierry
Copy link
Contributor

What does this PR do?

Fixes #31036
Try to fix with minimal change as mentioned in #31302

3D attention masks are used in GroundingDINO

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Let's document that as well and good to merge 🤗 (I think you might have some make fix copied to handle)

@gathierry
Copy link
Contributor Author

Hi I updated the doc string. But I'm not familiar with how the attention_mask doc is overwritten. If I did it in a wrong way, could you help indicate where should be change?
make fix-copies seems fine, but maybe I miss something.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gathierry
Copy link
Contributor Author

Hi @ArthurZucker , could you give another review on the updated code? Thanks

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, I think this is more support arbitrary mask shape, as if the shape is not == 2 we are just not expanding the mask!

The mask will be broadcasted to 4d eventually, which is why I would rather make sure our inputs are 4d or 2d TBH! But otherwise LGTM

@gathierry
Copy link
Contributor Author

Thanks for the comment @ArthurZucker. The use case is that Grounding DINO uses a 3D attention mask. Right now, all implementations have to creating a wrapper of BertModel. Accepting 3D mask is more straightforward but I'm also okay with the idea about only accepting 2d or 4d. Grounding DINO has to expand its attention mask beforehand though.
Do you want an assert in the beginning to make sure the input shape is either 2D or 4D?

@ArthurZucker
Copy link
Collaborator

I just want to make sure we "clarrify", as the 3D mask you are providing will be added to the attention scores, which shape cannot be other than 3d, meaning there is an implicit broadcast. Having a check like for Llama would be better IMO!

@gathierry
Copy link
Contributor Author

Could you give a reference of the check in Llama, I can try to do the same

@ArthurZucker
Copy link
Collaborator

MB, in

def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
we don't raise and error so let's go with that

@gathierry
Copy link
Contributor Author

For BERT it will go to some common functions which makes it more like the PR #31302

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok let's just recommend 4d over 3d

@@ -908,7 +908,7 @@ class BertForPreTrainingOutput(ModelOutput):
[`PreTrainedTokenizer.__call__`] for details.

[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
attention_mask (`torch.FloatTensor` of shape `({0})`or `(batch_size, sequence_length, target_length)`, *optional*):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's maybe recommend 4d here!

@gathierry
Copy link
Contributor Author

@ArthurZucker
Copy link
Collaborator

yeah we should not got down that code path. 4d mask should be untouched

@gathierry
Copy link
Contributor Author

gathierry commented Aug 30, 2024

Yeah but seems like 3D is compatible by get_extended_attention_mask from the beginning?

@ArthurZucker
Copy link
Collaborator

Okay, let's go with 3d then 😉

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ready to merge?

@gathierry
Copy link
Contributor Author

Yes, thanks

@ArthurZucker ArthurZucker merged commit 342e800 into huggingface:main Sep 6, 2024
21 checks passed
@ArthurZucker
Copy link
Collaborator

thanks for your contribution! 🤗

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Sep 6, 2024
* support 3D/4D attention mask in bert

* test cases

* update doc

* fix doc
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* support 3D/4D attention mask in bert

* test cases

* update doc

* fix doc
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
* support 3D/4D attention mask in bert

* test cases

* update doc

* fix doc
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
* support 3D/4D attention mask in bert

* test cases

* update doc

* fix doc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

sdpa for bert should support 4D attention mask.
3 participants