support 3D attention mask in bert #32105
Conversation
Cool! Let's document that as well and it's good to merge 🤗 (I think you might have some `make fix-copies` to handle)
Hi, I updated the docstring. But I'm not familiar with how the …
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @ArthurZucker, could you give another review of the updated code? Thanks
LGTM, I think this is more about supporting arbitrary mask shapes: if the shape is not == 2 we just don't expand the mask!
The mask will be broadcast to 4D eventually, which is why I would rather make sure our inputs are 4D or 2D TBH! But otherwise LGTM.
Thanks for the comment @ArthurZucker. The use case is that Grounding DINO uses a 3D attention mask. Right now, all implementations have to create a wrapper of …
I just want to make sure we clarify this: the 3D mask you are providing will be added to the attention scores, whose shape cannot be anything other than 4D, meaning there is an implicit broadcast. Having a check like the one for Llama would be better IMO!
Could you give a reference to the check in Llama? I can try to do the same.
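For context, here is an illustrative sketch of the implicit broadcast being discussed (not the actual transformers code, just the shapes involved): eager attention scores in BERT are `(batch, num_heads, query_len, key_len)`, so a 3D mask of shape `(batch, query_len, key_len)` gains a head axis and is repeated across heads when added.

```python
import torch

batch, num_heads, q_len, k_len = 2, 12, 7, 7

# 4D attention scores, as produced by query @ key^T in eager attention
scores = torch.randn(batch, num_heads, q_len, k_len)

# 3D mask: one row of allowed key positions per query token (1 = attend, 0 = ignore)
mask_3d = torch.randint(0, 2, (batch, q_len, k_len)).float()

# Turn keep/ignore values into an additive bias and insert a head axis;
# broadcasting then repeats it across all heads when it is added to the scores.
additive = (1.0 - mask_3d[:, None, :, :]) * torch.finfo(scores.dtype).min
scores = scores + additive
print(scores.shape)  # torch.Size([2, 12, 7, 7])
```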
My bad, see transformers `src/transformers/models/llama/modeling_llama.py`, lines 59 to 109 at 834ec7b.
For BERT it will go through some common functions, which makes it more like PR #31302.
Ok let's just recommend 4d over 3d
@@ -908,7 +908,7 @@ class BertForPreTrainingOutput(ModelOutput):
         [`PreTrainedTokenizer.__call__`] for details.

         [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+        attention_mask (`torch.FloatTensor` of shape `({0})` or `(batch_size, sequence_length, target_length)`, *optional*):
let's maybe recommend 4d here!
I just realized that if we use a 4D attention mask, there will be an error at https://github.com/gathierry/transformers/blob/97903a6352fdc897e13bc4f221fc57aa73e2697b/src/transformers/models/bert/modeling_bert.py#L1113
Yeah, we should not go down that code path. A 4D mask should be untouched.
Yeah, but it seems like 3D is compatible by …
Okay, let's go with 3D then 😉
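As a rough sketch of why 3D passes through while a raw 4D mask hits an error: BERT routes the mask through a shared `get_extended_attention_mask`-style helper that only knows how to expand 2D and 3D inputs. The function below is a simplified stand-in for that helper, not the exact library implementation, which differs in details such as the error message and causal handling.

```python
import torch


def extend_attention_mask(attention_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Illustrative stand-in for the shared helper the BERT code path calls."""
    if attention_mask.dim() == 3:
        # (batch, q_len, k_len) -> (batch, 1, q_len, k_len): broadcast over heads
        extended = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        # (batch, k_len) -> (batch, 1, 1, k_len): broadcast over heads and queries
        extended = attention_mask[:, None, None, :]
    else:
        # a raw 4D mask falls through to an error like this one,
        # which is why the thread settles on 3D
        raise ValueError(f"Wrong shape for attention_mask {tuple(attention_mask.shape)}")
    extended = extended.to(dtype=dtype)
    # convert 1/0 keep/ignore values into an additive bias
    return (1.0 - extended) * torch.finfo(dtype).min
```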
Ready to merge?
Yes, thanks
Thanks for your contribution! 🤗
* support 3D/4D attention mask in bert
* test cases
* update doc
* fix doc
What does this PR do?
Fixes #31036
Tries to fix this with a minimal change, as mentioned in #31302.
3D attention masks are used in Grounding DINO.
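For illustration only (the checkpoint name and mask contents are placeholders), a minimal sketch of the behaviour this PR enables: a 3D attention mask, one row of key positions per query token, can be passed straight to `BertModel` without writing a wrapper first.

```python
import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("a cat and a remote control", return_tensors="pt")
batch_size, seq_len = inputs["input_ids"].shape

# (batch_size, query_len, key_len) instead of the usual (batch_size, seq_len);
# all-ones here, but Grounding DINO-style models supply per-query text masks.
mask_3d = torch.ones(batch_size, seq_len, seq_len, dtype=torch.long)

outputs = model(input_ids=inputs["input_ids"], attention_mask=mask_3d)
print(outputs.last_hidden_state.shape)  # (batch_size, seq_len, hidden_size)
```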
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.