Add head_mask/decoder_head_mask for BART #9404
Conversation
Dear @patrickvonplaten and the rest of the HuggingFace group. I implemented head_mask and decoder_head_mask for BART, as described in the PR description below. Thank you very much for all your time in advance. I really do appreciate it.
After that I'm happy to get the tests passing together!
Hi @patrickvonplaten, the model should now be rebased according to #9343. :) I'll be more than happy to finish this PR with you. Thanks a lot in advance :)
@stancld, please do let me know if you're stuck and need help or if your PR is ready for review, just ping me here :-) |
Models: MBart, Marian, Blenderbot, BlenderbotSmall
Hi @patrickvonplaten, I would like to bring an update after the weekend off. First of all, I would like to apologise for a bit of a messy PR, as I was initially struggling with my local branch (I'll do better next time).
Besides, I think it might be desirable to implement some additional tests of head_mask for these models, but I leave that decision up to you. In any case, please let me know what needs to be done to complete this PR.
@patrickvonplaten I think this PR is ready for review. I've resolved one conflict that arose last night after an upstream commit.
tests/test_modeling_common.py
```
@@ -206,7 +206,12 @@ def test_forward_signature(self):
                "decoder_attention_mask",
                "encoder_outputs",
            ]
            self.assertListEqual(arg_names[:5], expected_arg_names)
            if model.config.model_type in ["bart", "mbart", "marian", "blenderbot", "blenderbot-small", "pegasus"]:
```
The order of the signature arguments IMO should be as follows:
```python
expected_arg_names = [
    "input_ids",
    "attention_mask",
    "decoder_input_ids",
    "decoder_attention_mask",
    "head_mask",
    "decoder_head_mask",
    "encoder_outputs",
]
```
The reason is that `decoder_input_ids` is more important than `head_mask` for torchscript. For models like Bart we would still like to be able to use torchscript as follows:

`traced_bart(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)`

instead of having to do

`traced_bart(input_ids, attention_mask, head_mask, decoder_input_ids, decoder_attention_mask)`

where `head_mask` would have to be a tensor of all 1's, since 99% of the time it's not used for torchscript. So it'd be great if we could slightly change the order in all `...Model` and all `...ForConditionalGeneration` models so that we have:
```python
expected_arg_names = [
    "input_ids",
    "attention_mask",
    "decoder_input_ids",
    "decoder_attention_mask",
    "head_mask",
    "decoder_head_mask",
    "encoder_outputs",
]
```
`head_mask` is just used too little for torchscript, so it's fine to break the (first all encoder inputs, then all decoder inputs) logic here.
We can adapt the test as follows:
```python
...
arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder:
    expected_arg_names = [
        "input_ids",
        "attention_mask",
        "decoder_input_ids",
        "decoder_attention_mask",
    ]
    expected_arg_names.extend(
        ["head_mask", "decoder_head_mask", "encoder_outputs"]
        if "head_mask" in arg_names
        else ["encoder_outputs"]
    )
...
```
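For reference, a minimal sketch of what such a signature check boils down to (the model class is only an illustration, and later library versions may add further arguments after `decoder_head_mask`):

```python
import inspect

from transformers import BartModel

signature = inspect.signature(BartModel.forward)
# Drop `self`, which shows up when inspecting the unbound method.
arg_names = [name for name in signature.parameters if name != "self"]

# At the time of this PR, the first arguments were expected to be:
# input_ids, attention_mask, decoder_input_ids, decoder_attention_mask,
# head_mask, decoder_head_mask, encoder_outputs, ...
print(arg_names[:7])
```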
```
@@ -395,10 +400,31 @@ def _create_and_check_torchscript(self, config, inputs_dict):
                attention_mask = inputs["attention_mask"]
                decoder_input_ids = inputs["decoder_input_ids"]
                decoder_attention_mask = inputs["decoder_attention_mask"]

                traced_model = torch.jit.trace(
```
Let's try not to change this test. `head_mask` is more or less never used in `torch.jit.trace`. If you're keen, we could overwrite this test in all Bart-like models to include the head_mask, but it's not mandatory at all IMO. However, we don't really like these `if model.config.model_type not in ...` statements in the tests.
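To make the torchscript point above concrete, here is a hedged sketch of the intended usage; the tiny random config is an assumption for illustration and is not taken from the test suite:

```python
import torch
from transformers import BartConfig, BartModel

# Tiny config so tracing is fast; torchscript=True makes the model return tuples.
config = BartConfig(
    vocab_size=100,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    torchscript=True,
)
model = BartModel(config).eval()

input_ids = torch.randint(4, 100, (1, 7))
attention_mask = torch.ones_like(input_ids)
decoder_input_ids = torch.randint(4, 100, (1, 5))
decoder_attention_mask = torch.ones_like(decoder_input_ids)

# With head_mask placed after the four "core" arguments, tracing and calling the
# traced model never requires passing a dummy head_mask of all ones.
traced_bart = torch.jit.trace(
    model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
)
outputs = traced_bart(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
```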
Hey @stancld, this is a super nice PR. It's very clean, and that without any help - awesome! I think there are 3 things we should change/add, among them switching `attentions = outputs.attentions` to `encoder_attentions = outputs.encoder_attentions` for the encoder case and `decoder_attentions = outputs.decoder_attentions` for the decoder case. I can also help you with 3) in case you're stuck.
Hey @patrickvonplaten, thanks a lot for your thorough feedback. I hope to come back later today with a new commit fixing the listed issues :)
* Change the order of input arguments so that the first 4 args always follow the pattern: `input_ids, attention_mask, decoder_input_ids, decoder_attention_mask`
* Remove "hard-coded" BART-related conditions in test_modeling_common.py
* Enable test_headmasking for BART-based models. This requires replacing
```
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
with
```
self.assertAlmostEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
so that encoder-decoder models pass the test (likely caused by precision issues, since the sum above equals 0.0)
…formers into head_mask_for_bart
Revert
```
self.assertAlmostEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
back to
```
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
as it was changed by mistake.
* This is the only test BART-like models cannot pass at this moment.
Hey @patrickvonplaten, this PR is again ready for review after making some changes according to your notes above. The one problem at this moment is that BART-like models do not satisfy one condition in `test_headmasking`.
I am not sure whether the formula for masking attention heads (in BART-like models) is implemented correctly. Anyway, I hope we will solve this issue and merge this PR. :)
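For what it's worth, here is a toy sketch of the masking scheme being discussed; this is an assumption based on how BERT applies `head_mask` (the per-layer mask is broadcast over the attention probabilities), not the PR's exact code:

```python
import torch

batch, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 5, 8
attn_probs = torch.softmax(torch.randn(batch, num_heads, tgt_len, src_len), dim=-1)
values = torch.randn(batch, num_heads, src_len, head_dim)

layer_head_mask = torch.ones(num_heads)
layer_head_mask[0] = 0.0  # mask head 0 of this layer

# Reshape to (1, num_heads, 1, 1) so the mask broadcasts over batch and sequence dims.
masked_probs = layer_head_mask.view(1, -1, 1, 1) * attn_probs
attn_output = masked_probs @ values  # head 0 now contributes exactly zero

print(masked_probs[:, 0].abs().sum())  # tensor(0.) -> the masked head's weights are zero
```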
I made some mistakes while updating my branch, which resulted in the PR tracking files I did not actually edit myself. I find this quite inconvenient and have failed to repair the issue so far. Therefore, I've created a new (clean) branch, which can be found here: https://github.com/stancld/transformers/tree/head_mask_for_bart_new. If you, @patrickvonplaten, are okay with that, I will close this PR (after resolving the rather minor issues raised in our discussion above) and create a new one from the branch referenced above, to make everything nice and clean before an eventual merge.
@stancld absolutely! Feel free to close this PR and open a new one :-) This happens to me all the time as well.
We can just link this closed PR to the new PR to have a reference to the discussion we had.
@patrickvonplaten - Great, you can find the newly opened PR at #9569 :)
Description:
This PR adds `head_mask` and `decoder_head_mask` to the BART PyTorch implementation, following the BERT implementation.

Motivation:
According to HuggingFace's website, "There is a growing field of study concerned with investigating the inner working of large-scale transformers like BERT (that some call “BERTology”)." This PR makes it possible to mask attention heads in the encoder and decoder exactly as for BERT, and thus creates an opportunity to study the importance of attention heads in encoder-decoder BERT-like models.
Reviewer: @patrickvonplaten
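For readers landing here later, a hedged usage sketch of the feature this PR introduces; the checkpoint name, the choice of `decoder_input_ids`, and the mask convention (1.0 keeps a head, 0.0 masks it, mirroring BERT) are assumptions for illustration:

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

inputs = tokenizer("Head masking is useful for BERTology.", return_tensors="pt")

# One mask entry per (layer, head); 1.0 keeps a head, 0.0 masks it out.
head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
head_mask[0, 0] = 0.0  # silence head 0 of encoder layer 0
decoder_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)
decoder_head_mask[1, 2] = 0.0  # silence head 2 of decoder layer 1

outputs = model(
    **inputs,
    decoder_input_ids=inputs["input_ids"],
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    output_attentions=True,
)
encoder_attentions = outputs.encoder_attentions  # per-layer encoder attention weights
decoder_attentions = outputs.decoder_attentions  # per-layer decoder attention weights
```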