
[EncoderDecoder] Add Cross Attention for GPT2 #6415

Merged

Conversation


@patrickvonplaten patrickvonplaten commented Aug 11, 2020

This PR implements Bert2GPT2 by adding cross-attention layers to GPT2.

Note that decoder generation cannot yet be sped up within the encoder-decoder framework (by using GPT2's past tensors): caching first has to be implemented for all models compatible with the framework (Bert, Roberta) before it can be enabled there.

All GPT2 RUN_SLOW tests are verified to pass.

Future PRs TODO:

  • Verify that Bert2GPT2 works by training on CNN Daily Mail summarization
  • Add smart caching to Bert and add it to the encoder-decoder framework
  • Update encoder-decoder docs
  • Add a notebook explaining how to use encoder-decoder models.
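Until that notebook exists, here is a minimal sketch of what the added cross-attention enables through the EncoderDecoder framework. The checkpoint names, tokenizer pairing, and generate() arguments below are illustrative assumptions rather than code from this PR:

    # Minimal Bert2GPT2 sketch (illustrative; not taken from this PR).
    # BERT encodes the source text; GPT2, now equipped with cross-attention
    # layers, decodes while attending to BERT's encoder hidden states.
    from transformers import BertTokenizer, GPT2Tokenizer, EncoderDecoderModel

    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")

    encoder_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    decoder_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    article = "A long news article that should be summarized ..."
    input_ids = encoder_tokenizer(article, return_tensors="pt").input_ids

    # Without the smart caching mentioned above, generation works but is not
    # yet sped up by GPT2's past tensors.
    summary_ids = model.generate(
        input_ids,
        decoder_start_token_id=model.config.decoder.bos_token_id,
        max_length=64,
    )
    print(decoder_tokenizer.decode(summary_ids[0], skip_special_tokens=True))

Note that such a Bert2GPT2 checkpoint would still need to be fine-tuned (e.g. on CNN Daily Mail, as listed above) before the generated text is useful.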

@patrickvonplaten force-pushed the add_gpt2_encoder_decoder branch from 7e7c8ad to ad5af2c on August 13, 2020 07:05
@patrickvonplaten changed the title from "[WIP][EncoderDecoder] Add Cross Attention for GPT2" to "[EncoderDecoder] Add Cross Attention for GPT2" on Aug 13, 2020
Collaborator

@sgugger sgugger left a comment


Thanks, this looks great to me!

Member

@LysandreJik LysandreJik left a comment


Great, looks good to me!

patrickvonplaten and others added 2 commits August 14, 2020 09:25
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

codecov bot commented Aug 14, 2020

Codecov Report

Merging #6415 into master will decrease coverage by 0.00%.
The diff coverage is 96.61%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #6415      +/-   ##
==========================================
- Coverage   79.98%   79.98%   -0.01%     
==========================================
  Files         153      153              
  Lines       28005    28039      +34     
==========================================
+ Hits        22401    22427      +26     
- Misses       5604     5612       +8     
Impacted Files Coverage Δ
src/transformers/modeling_encoder_decoder.py 91.66% <87.50%> (+0.64%) ⬆️
src/transformers/modeling_gpt2.py 86.68% <97.87%> (+0.71%) ⬆️
src/transformers/generation_utils.py 96.94% <100.00%> (+0.01%) ⬆️
src/transformers/modeling_tf_distilbert.py 64.47% <0.00%> (-32.95%) ⬇️
src/transformers/generation_tf_utils.py 86.71% <0.00%> (+7.51%) ⬆️
src/transformers/modeling_tf_flaubert.py 87.73% <0.00%> (+63.19%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@patrickvonplaten patrickvonplaten merged commit 1d6e71e into huggingface:master Aug 14, 2020
sgugger added a commit that referenced this pull request Aug 14, 2020
* Generation doc
* MBartForConditionalGeneration (#6441)
* add MBartForConditionalGeneration
* style
* rebase and fixes
* add mbart test in TEST_FILES_WITH_NO_COMMON_TESTS
* fix docs
* don't ignore mbart
* doc
* fix mbart fairseq link
* put mbart before bart
* apply doc suggestions
* Use hash to clean the test dirs (#6475)
* Use hash to clean the test dirs
* Use hash to clean the test dirs
* Use hash to clean the test dirs
* fix
* [EncoderDecoder] Add Cross Attention for GPT2 (#6415)
* add cross attention layers for gpt2
* make gpt2 cross attention work
* finish bert2gpt2
* add explicit comments
* remove attention mask since not yet supported
* revert attn mask in pipeline
* Update src/transformers/modeling_gpt2.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_encoder_decoder.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Sort unique_no_split_tokens to make it deterministic (#6461)
* change unique_no_split_tokens's type to set
* use sorted list instead of set
* style
* Import accuracy_score (#6480)
* Apply suggestions from code review
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Address comments
* Styling
* Generation doc
* Apply suggestions from code review
  Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
* Address comments
* Styling

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Kevin Canwen Xu <canwenxu@126.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: gijswijnholds <gijswijnholds@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@patrickvonplaten patrickvonplaten mentioned this pull request Oct 25, 2022