
Add head_mask/decoder_head_mask for BART #9404

Closed
wants to merge 16 commits

Conversation

@stancld (Contributor) commented Jan 4, 2021

Description:
This PR adds head_mask and decoder_head_mask to the BART PyTorch implementation, following the BERT implementation.

Motivation:
According to HuggingFace's website, "There is a growing field of study concerned with investigating the inner working of large-scale transformers like BERT (that some call “BERTology”)." This PR makes it possible to mask attention heads in the encoder and decoder exactly as for BERT, and thus creates an opportunity to study the importance of attention heads in encoder-decoder BERT-like models.
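
For context, head masking (as it already works for BERT, and as this PR proposes for BART) is controlled by a (num_layers, num_heads) tensor of 0s and 1s that zeroes out individual attention heads. Below is a rough usage sketch, assuming the head_mask/decoder_head_mask arguments added here are available on BartForConditionalGeneration; it is illustrative only and not code from the PR:

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

inputs = tokenizer("Studying the heads of seq2seq models.", return_tensors="pt")

# One row per layer, one column per head; 0 disables a head, 1 keeps it.
head_mask = torch.ones(model.config.encoder_layers, model.config.encoder_attention_heads)
head_mask[0, 0] = 0  # switch off the first head of the first encoder layer

decoder_head_mask = torch.ones(model.config.decoder_layers, model.config.decoder_attention_heads)

outputs = model(
    **inputs,
    decoder_input_ids=inputs["input_ids"],
    head_mask=head_mask,
    decoder_head_mask=decoder_head_mask,
    output_attentions=True,
)
```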

Reviewer: @patrickvonplaten

@stancld stancld marked this pull request as draft January 5, 2021 08:39
@stancld stancld changed the title from "added head_mask/decoder_head_mask for BART" to "Add head_mask/decoder_head_mask for BART" Jan 5, 2021
@stancld (Contributor Author) commented Jan 7, 2021

Dear @patrickvonplaten and the rest of HuggingFace group.

I implemented the concept of head_mask from BERT in BART so that the internals of encoder-decoder models can be studied as well. However, as this is my very first attempt to contribute to such a large-scale open-source project, I have been struggling a bit to pass the tests. Could you please guide me on what needs to be done to achieve a valid pull request?

Thank you very much for all your time in advance. I really do appreciate it.

@patrickvonplaten patrickvonplaten marked this pull request as ready for review January 7, 2021 21:07
@patrickvonplaten (Contributor)

Hi @stancld - thanks a lot for pinging me! I'm happy to help you here :-) I think your PR is a nice addition. Sadly, we made many changes to Bart recently (see #9343), so you'll probably have to rebase your PR onto the current version of master.

@patrickvonplaten (Contributor)

After that I'm happy to get the tests passing together!

@stancld (Contributor Author) commented Jan 8, 2021

Hi @patrickvonplaten, the branch should now be rebased to reflect the changes from #9343. :) I'll be more than happy to finish this PR with you. Thanks a lot in advance :)

@patrickvonplaten (Contributor)

@stancld, please do let me know if you're stuck and need help or if your PR is ready for review, just ping me here :-)

@stancld (Contributor Author) commented Jan 11, 2021

Hi @patrickvonplaten, I would like to bring an update after the weekend off.

First of all, I would like to apologise for a somewhat messy PR, as I was initially struggling with it on my local machine (I'll do better next time).
Regarding this PR: to pass all the tests, head_mask and decoder_head_mask are now implemented for the following PyTorch BART-based models:

  • BART,
  • MBart,
  • Blenderbot,
  • BlenderbotSmall,
  • Marian,
  • Pegasus.

Besides, I think it might be desirable to implement some additional tests for head_mask for these models, but I leave this decision up to you. In any case, please let me know what needs to be done to complete this PR.

@stancld (Contributor Author) commented Jan 12, 2021

@patrickvonplaten I think this PR is ready for review. I've just resolved one conflict that arose last night after a commit to master; I've been tracking changes on my local branch and everything still seems to be working.

```
@@ -206,7 +206,12 @@ def test_forward_signature(self):
        "decoder_attention_mask",
        "encoder_outputs",
    ]
    self.assertListEqual(arg_names[:5], expected_arg_names)
    if model.config.model_type in ["bart", "mbart", "marian", "blenderbot", "blenderbot-small", "pegasus"]:
```
@patrickvonplaten (Contributor) commented on this diff:

The order of the signature arguments IMO should be as follows:

```python
expected_arg_names = [
    "input_ids",
    "attention_mask",
    "decoder_input_ids",
    "decoder_attention_mask",
    "head_mask",
    "decoder_head_mask",
    "encoder_outputs",
]
```

The reason is that decoder_input_ids is more important than head_mask for torchscript.

For models like Bart we would still like to be able to use torchscript as follows:

```python
traced_bart(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
```

instead of having to do

```python
traced_bart(input_ids, attention_mask, head_mask, decoder_input_ids, decoder_attention_mask)
```

where head_mask would have to be a tensor of all 1's, since 99% of the time it's not used for torchscript.
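
For reference, here is a minimal sketch of how such a traced model is produced and called once head_mask sits after the decoder inputs. It assumes a tiny randomly initialised BartModel; the config values are placeholders, and torchscript=True is needed so the model returns tuples rather than ModelOutput objects:

```python
import torch
from transformers import BartConfig, BartModel

# Tiny illustrative config; torchscript=True switches the model to tuple
# outputs, which torch.jit.trace requires.
config = BartConfig(
    torchscript=True,
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    vocab_size=100,
)
model = BartModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 7))
attention_mask = torch.ones_like(input_ids)
decoder_input_ids = input_ids.clone()
decoder_attention_mask = torch.ones_like(input_ids)

# trace() binds the example inputs positionally, so the forward-signature order
# matters: with head_mask placed after decoder_attention_mask, the usual
# 4-tensor call keeps working without passing a dummy all-ones head_mask.
traced_bart = torch.jit.trace(
    model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
)
outputs = traced_bart(input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
```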

So it'd be great if we can slightly change the order in all ...Model and all ...ForConditionalGeneration models so that we have:

```python
expected_arg_names = [
    "input_ids",
    "attention_mask",
    "decoder_input_ids",
    "decoder_attention_mask",
    "head_mask",
    "decoder_head_mask",
    "encoder_outputs",
]
```

head_mask is simply used too little for torchscript, so it's fine to break the (first all encoder inputs, then all decoder inputs) logic here.

We can adapt the test as follows:

```python
...
arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder:
    expected_arg_names = [
        "input_ids",
        "attention_mask",
        "decoder_input_ids",
        "decoder_attention_mask",
    ]

    expected_arg_names.extend(
        ["head_mask", "decoder_head_mask", "encoder_outputs"]
        if "head_mask" in arg_names
        else ["encoder_outputs"]
    )
    ...
```

```
@@ -395,10 +400,31 @@ def _create_and_check_torchscript(self, config, inputs_dict):
    attention_mask = inputs["attention_mask"]
    decoder_input_ids = inputs["decoder_input_ids"]
    decoder_attention_mask = inputs["decoder_attention_mask"]

    traced_model = torch.jit.trace(
```
@patrickvonplaten (Contributor) commented on this diff:

Let's try not to change this test. head_mask is more or less never used in torch.jit.trace. If you're keen, we could overwrite this test in all Bart-like models to include the head_mask, but it's not mandatory at all IMO.

However, we don't really like these `if model.config.model_type not in ...` statements in the tests.

@patrickvonplaten (Contributor)

Hey @stancld,

This is a super nice PR. It's very clean and that without any help - awesome!

I think there are 3 things we should change/add:

  1. I think we should change the order of the forward args of all ...Model and ...ForConditionalGeneration as explained above. This a) means that there is no breaking change in the way Bart is used with torchscript, and b) it's the better option IMO as well, since the first 4 args should always be input_ids, attention_mask, decoder_input_ids, decoder_attention_mask for EncDec models.

  2. Let's try to remove all "hard-coded" model names in the common tests. I've commented above. We don't really need to test torchscript with head_mask and for the signature it'd be better to change it according to 1)

  3. It would be awesome if you could add an if model.config.is_encoder_decoder branch to the test_headmasking test in test_modeling_common.py that tests head masking correctly for Seq2Seq models. To enable this test for all Bart-like models you'll have to set test_head_masking to True in BartModelTest and the others. One thing we'll have to adapt in the test is that we should change the line:

```python
attentions = outputs[-1]
```

to

```python
attentions = outputs.attentions
```

for the model.config.is_encoder_decoder is False case, and to

```python
encoder_attentions = outputs.encoder_attentions
decoder_attentions = outputs.decoder_attentions
```

for the other case (a rough sketch of this branch follows below).
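
In case it helps, here is a minimal, self-contained sketch of the kind of check that branch would perform. It is not the actual test_modeling_common.py code; it uses a tiny randomly initialised BartModel with placeholder config values and assumes the head_mask/decoder_head_mask arguments from this PR are available:

```python
import torch
from transformers import BartConfig, BartModel

# Tiny illustrative config so the check runs quickly.
config = BartConfig(
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    vocab_size=100,
)
model = BartModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 7))

# Mask the first head of the first layer and all but the last head of the last layer.
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0

with torch.no_grad():
    outputs = model(
        input_ids,
        decoder_input_ids=input_ids,
        head_mask=head_mask,
        decoder_head_mask=head_mask,
        output_attentions=True,
    )

# For encoder-decoder models the attentions come back split per stack.
encoder_attentions = outputs.encoder_attentions
decoder_attentions = outputs.decoder_attentions

# Masked heads should produce zero attention weights, unmasked heads should not.
assert encoder_attentions[0][:, 0].abs().sum().item() == 0.0
assert encoder_attentions[0][:, 1:].abs().sum().item() != 0.0
assert decoder_attentions[-1][:, -1].abs().sum().item() != 0.0
```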

I can also help you with 3) in case you're stuck.
Really impressed by how clean the PR is! Think there is not much left to do. 1) and 2) are very easy changes and 3) will require a bit more time, but should be fine as well.

@stancld (Contributor Author) commented Jan 12, 2021

Hey @patrickvonplaten, thanks a lot for your thorough feedback. I expect to come back later today with a new commit fixing the listed issues :)

* Change the order of input arguments so that the first 4 args always
follow the pattern:
`input_ids, attention_mask, decoder_input_ids, decoder_attention_mask`

* Remove "hard-coded" BART-related conditions in test_modelling_common.py

* Enable test_headmasking for BART-based models. This requires replacing:
```
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
with
```
self.assertAlmostEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
so that the test passes for encoder-decoder models (likely due to precision
issues, since the sum above equals 0.0)
Revert the assertion:
```
self.assertAlmostEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
back to
```
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```
as it was mistakenly changed

* This is the only test BART-like models cannot pass at this moment
@stancld (Contributor Author) commented Jan 12, 2021

Hey @patrickvonplaten, this PR is again ready for review after making some changes according to your notes above. The one problem at this moment is that BART-like models do not satisfy one condition in test_headmasking:

```python
self.assertNotEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
```

I am not sure whether the formula for masking attention heads (in BART-like models) is implemented correctly. Now, if head_mask in the test case is specified as

```python
head_mask = torch.ones(
    self.model_tester.num_hidden_layers,
    self.model_tester.num_attention_heads,
    device=torch_device,
)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0
```

then outputs.encoder_attentions[1][..., :, :, :] or outputs.decoder_attentions[1][..., :, :, :] equals a tensor of 0.0 for all examples over all heads but the last one. This is not the case, however, for non-encoder-decoder models with attentions[1][..., :, :, :]. Do you have any idea where the problem might be?

Anyway, I hope we will solve this issue and merge this PR. :)

@stancld (Contributor Author) commented Jan 12, 2021

I made some mistakes while updating my branch, which resulted in the PR tracking files I had not actually edited myself. I find this quite inconvenient and have failed to repair it so far. Therefore, I've created a new (clean) branch, which can be found here: https://github.com/stancld/transformers/tree/head_mask_for_bart_new.

If you, @patrickvonplaten, are okay with that, I will close this PR (after resolving the rather minor issues raised in our discussion above) and create a new one from the branch referenced above, to make everything nice and clean before an eventual merge.

@patrickvonplaten (Contributor)

@stancld absolutely! Feel free to close this PR and open a new one :-) This happens to me all the time as well

@patrickvonplaten (Contributor)

We can just link this closed PR to the new PR to have a reference to the discussion we had

@stancld (Contributor Author) commented Jan 13, 2021

@patrickvonplaten - Great, you can find the newly opened PR at #9569 :)

@stancld stancld closed this Jan 13, 2021