Trying to add support for GPT2 as decoder in EncoderDecoder model #4483
@dimi1357 out of curiosity, what does training this look like? |
This is my training loop:

```python
x, encoder_attention_mask, y, decoder_attention_mask, _ = batch
x = x.to(self.device)
y = y.to(self.device)
encoder_attention_mask = encoder_attention_mask.to(self.device)
decoder_attention_mask = decoder_attention_mask.to(self.device)
model_kwargs = {
    "attention_mask": encoder_attention_mask,
    "decoder_attention_mask": decoder_attention_mask,
    "lm_labels": y,
}
self.optimizer.zero_grad()
outputs = self.model(input_ids=x, decoder_input_ids=y, **model_kwargs)
loss = outputs[0]
loss.backward()
self.optimizer.step()
if self.scheduler is not None:
    self.scheduler.step()
```

and I create the model this way:
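A minimal sketch of one way such a model could be constructed (an assumption based on the `EncoderDecoderModel` API, not necessarily the exact code used here):

```python
# Hypothetical construction sketch; assumes GPT2 decoder support, either via the
# patch described in this issue or a transformers release that includes it.
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",  # encoder
    "gpt2",               # decoder; cross-attention weights are newly initialized
)
```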
|
@dimi1357 Did you finally make it work? Could you share the "full changes" in some way? I am also interested in using the GPT2 model as a decoder. |
Thanks for the feature request and the detailed code! I will think a bit more about how to implement this and get back to you! |
I forgot to add the change I've made to the `Block` forward function:

```python
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None,
            use_cache=False, encoder_hidden_states=None, encoder_attention_mask=None):
    output_attn = self.attn(
        self.ln_1(x),
        layer_past=layer_past,
        attention_mask=attention_mask,
        head_mask=head_mask,
        use_cache=use_cache,
    )
    a = output_attn[0]  # output_attn: a, present, (attentions)
    outputs = []
    if self.is_decoder and encoder_hidden_states is not None:
        cross_attention_outputs = self.crossattention(
            a, layer_past, attention_mask, head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        a = cross_attention_outputs[0]
        outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
    x = x + a
    m = self.mlp(self.ln_2(x))
    x = x + m
    outputs = [x] + output_attn[1:] + outputs
    return outputs  # x, present, (attentions)
```
|
You can add the code above to your local installation of the transformers package, but I'm still not sure that this implementation is correct, so I suggest you wait for the huggingface team to confirm whether it is okay. |
Hey @dimi1357. So I think the Encoder Decoder roadmap is as follows:
I will keep your code sample here in mind for this :-) |
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. |
Hi, |
Hey, I will take a look at BERTGPT2 encoder-decoder probably on Monday next week |
@patrickvonplaten Can you please share a work in progress notebook/colab, or some code. I am willing to help with tests and datasets, in order to improve the BERT2GPT2 model. Thank you :D |
Will finish the PR tomorrow then it should be pretty easy to do BERT2GPT2. |
Hi @patrickvonplaten. I've used your latest commit to train BERT2GPT2 using your BERT2BERT training tutorial. It was straightforward; I only had to replace the "bert" decoder with "gpt2". The training worked, but at inference time there was a code error in
I do not know if the model requires a different evaluation approach. |
Thanks for the implementation, I'm going to test it now. |
GPT2 is added and results on summarization look promising. Check out this model (Bert2GPT2 trained on CNN/Daily Mail), including train and eval scripts: https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16 . |
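For anyone who wants to try such a checkpoint, a minimal inference sketch might look like the following (the tokenizer choices and generation settings are assumptions, not the model card's exact recipe):

```python
# Hypothetical usage sketch for the linked Bert2GPT2 checkpoint; tokenizer choices
# and generation arguments are illustrative assumptions.
from transformers import AutoTokenizer, EncoderDecoderModel

model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
enc_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # encoder side
dec_tokenizer = AutoTokenizer.from_pretrained("gpt2")             # decoder side

article = "Some long news article to summarize ..."
inputs = enc_tokenizer(article, return_tensors="pt", truncation=True, max_length=512)

summary_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=64,
    num_beams=4,
)
print(dec_tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```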
Hi @patrickvonplaten, I used this model card to train on my custom dataset, but again a TypeError is thrown, saying that
If you look carefully, you can see that an argument is missing in |
You need to switch to this branch: https://github.com/huggingface/transformers/tree/more_general_trainer_metric to make the training work. I am trying to integrate this branch into master soon :-) |
Thanks for letting me know. |
Sorry to ask a question after a long period of time :-). I am still not very clear about the effect of the encoder attention mask in GPT2. I understand that it is used only in the decoder of an encoder-decoder model to modify the cross-attention weights. Also, I noticed the operation defined in modeling_gpt2.py: However, I am confused about why we need this encoder attention mask. Is that also because the decoder cannot see the whole sequence? Thanks for the help :-) |
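For reference, the encoder attention mask typically hides padding positions in the encoder output rather than enforcing causality; a simplified sketch of that kind of masking (not the library's exact code) is:

```python
# Simplified sketch (not the actual transformers implementation) of how an encoder
# padding mask is folded into cross-attention: 1 = real token, 0 = padding.
import torch

def cross_attention_scores(q, k, encoder_attention_mask):
    # q: (batch, heads, tgt_len, head_dim), k: (batch, heads, src_len, head_dim)
    scores = torch.matmul(q, k.transpose(-1, -2)) / (q.size(-1) ** 0.5)
    if encoder_attention_mask is not None:
        # (batch, src_len) -> (batch, 1, 1, src_len); zeros become a large negative
        # bias so padded encoder positions receive ~zero attention weight.
        mask = encoder_attention_mask[:, None, None, :].to(scores.dtype)
        scores = scores + (1.0 - mask) * torch.finfo(scores.dtype).min
    return torch.softmax(scores, dim=-1)
```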
@AmbiTyga @patrickvonplaten Is this error fixed? I have switched to the "more_general_trainer_metric" branch, but it seems this error still exists when I run the code in https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16. |
The code is a bit outdated there. You should be able to simply use the https://github.com/huggingface/transformers/tree/master/examples/pytorch/summarization example. In order to create a BERT2GPT2 checkpoint, you could use code similar to this one: https://huggingface.co/docs/transformers/v4.17.0/en/model_doc/encoder-decoder#transformers.EncoderDecoderModel.forward (just replace one BERT with GPT2). So to summarize,
I'll keep this issue open for now since we should probably create a nice "How-to" guide for this |
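A hedged sketch of the checkpoint-creation step described above (the token-id settings are assumptions about a reasonable setup, not the official recipe):

```python
# Rough sketch of creating and saving a BERT2GPT2 checkpoint that the
# summarization example script can then fine-tune.
from transformers import AutoTokenizer, EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")

bert_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
gpt2_tok.pad_token = gpt2_tok.eos_token  # GPT2 has no pad token by default

model.config.decoder_start_token_id = gpt2_tok.bos_token_id
model.config.eos_token_id = gpt2_tok.eos_token_id
model.config.pad_token_id = gpt2_tok.pad_token_id

model.save_pretrained("bert2gpt2-checkpoint")  # point run_summarization.py at this path
```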
Thanks for your guidance! I tried this method to create and fine-tune a bert2gpt2 model, but it seems that the "tokenizer" would be a problem: I can't load a single suitable tokenizer for this model in the summarization example. So is it necessary for me to define tokenizer1 for BERT and tokenizer2 for GPT2 and then change any code that is related to "tokenizer" in order to fix this problem? @patrickvonplaten |
It's fine to load two tokenizers, no? |
Yeah, I used two tokenizers to replace "tokenizer" in run_summarization.py and made some other changes; the code works now (although I don't know whether it is right...). Here are my changes.
|
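A minimal sketch of what such a two-tokenizer preprocessing might look like (checkpoint names, dataset columns, and length limits are hypothetical):

```python
# Hypothetical sketch of two-tokenizer preprocessing for a BERT encoder and a GPT2
# decoder; column names and max lengths are illustrative only.
from transformers import AutoTokenizer

enc_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dec_tokenizer = AutoTokenizer.from_pretrained("gpt2")
dec_tokenizer.pad_token = dec_tokenizer.eos_token

def preprocess(examples):
    model_inputs = enc_tokenizer(
        examples["article"], max_length=512, truncation=True, padding="max_length"
    )
    labels = dec_tokenizer(
        examples["summary"], max_length=128, truncation=True, padding="max_length"
    )
    # Replace pad positions with -100 so they are ignored by the loss.
    model_inputs["labels"] = [
        [tok if tok != dec_tokenizer.pad_token_id else -100 for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs
```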
Hey everyone, |
Can I work on this issue as a good first issue or is there no point? |
I don't think there is any point @Forpee |
You can just use the current implementation described in the docs:
Why do you want to do that, given the cited performance reduction? |
I am trying to train it on a question generation task to compare the results
🚀 Feature request
Hi,
I am trying to add the option of using GPT2 as the decoder in the EncoderDecoder model, which currently only supports
Motivation
For a generation problem, it is usually better to use GPT2 as the decoder rather than BERT.
Your contribution
I've made the following changes in the `modeling_gpt2.py` file:

- `Block` class: added a cross-attention call to the forward pass (the modified forward is shown earlier in this thread).
- Added `encoder_attention_mask` and `encoder_hidden_states` to the forward function of the `Attention` class, and used them for the key and the value if they are provided.
- Added `encoder_attention_mask` and `encoder_hidden_states` arguments to the `GPT2Model` forward function, and processed `encoder_attention_mask` the same way as `attention_mask`.
- Added `encoder_attention_mask` and `encoder_hidden_states` arguments to the `GPT2LMHeadModel` forward function, as well as `lm_labels` and `masked_lm_labels` for EncoderDecoder model compatibility (probably it's better to use `GPT2DoubleHeadsModel`).

My biggest concern is with the second bullet, and I wanted to ask you if this implementation seems right; a rough sketch of the idea follows below. (For now it looks like I am able to train and test an EncoderDecoder with a BERT2GPT architecture.)
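As an illustration of what the second bullet describes (projection names are hypothetical and padding masking via `encoder_attention_mask` is omitted here; the masking sketch earlier in the thread covers that part):

```python
# Simplified sketch of the idea in the second bullet: when encoder_hidden_states are
# provided, keys and values come from the encoder output instead of the decoder's own
# states. Illustrative only; GPT2's real Attention uses a fused Conv1D (c_attn) and
# also handles heads, caching, and masking.
import torch
from torch import nn

class CrossAttentionSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, x, encoder_hidden_states=None):
        query = self.q_proj(x)  # queries always come from the decoder states
        # Keys/values come from the encoder output when provided (cross-attention),
        # otherwise from the decoder's own states (ordinary self-attention).
        kv_source = encoder_hidden_states if encoder_hidden_states is not None else x
        key, value = self.k_proj(kv_source), self.v_proj(kv_source)
        scores = torch.matmul(query, key.transpose(-1, -2)) / (key.size(-1) ** 0.5)
        return torch.matmul(torch.softmax(scores, dim=-1), value)
```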
Of course, if needed, I can provide the full code for all of my changes, but all of them are listed above.
Most (if not all) of the code I've added is adapted from the huggingface `modeling_bert.py` file, so all of the credit goes to them. Thanks!