
[EncoderDecoder] Add encoder-decoder for roberta/ vanilla longformer #6411

Conversation

@patrickvonplaten (Contributor) commented Aug 11, 2020

This PR adds RoBERTa to the EncoderDecoder framework, which automatically makes it possible to use both Roberta2Roberta and Longformer2Roberta models:

import torch
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
input_ids = torch.tensor([10 * [0]])  # dummy batch with a single sequence of 10 tokens
model(input_ids=input_ids, decoder_input_ids=input_ids)

and

import torch
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("allenai/longformer-base-4096", "roberta-base")
input_ids = torch.tensor([10 * [0]])  # dummy batch with a single sequence of 10 tokens
model(input_ids=input_ids, decoder_input_ids=input_ids)
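
As a rough sketch of how such a composed model could then be fine-tuned (not part of this PR's diff; it assumes the labels argument of EncoderDecoderModel's forward pass to obtain a language-modeling loss, and uses the checkpoint names above):

from transformers import EncoderDecoderModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Tie two pretrained checkpoints together; the decoder's cross-attention weights are newly initialized.
model = EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")

inputs = tokenizer("A long article that should be summarized.", return_tensors="pt")
targets = tokenizer("A short summary.", return_tensors="pt")

# Passing labels makes the model compute a language-modeling loss on the decoder outputs.
outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    decoder_input_ids=targets.input_ids,
    labels=targets.input_ids,
    return_dict=True,
)
print(outputs.loss)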

Also pinging @ibeltagy and @patil-suraj

@patrickvonplaten changed the title from "[EncoderDecoder] Add encoder-decoder for roberta" to "[EncoderDecoder] Add encoder-decoder for roberta/ vanilla longformer" on Aug 11, 2020
codecov bot commented Aug 11, 2020

Codecov Report

Merging #6411 into master will decrease coverage by 1.93%.
The diff coverage is 92.85%.


@@            Coverage Diff             @@
##           master    #6411      +/-   ##
==========================================
- Coverage   79.77%   77.84%   -1.94%     
==========================================
  Files         150      150              
  Lines       27789    27826      +37     
==========================================
- Hits        22170    21660     -510     
- Misses       5619     6166     +547     
Impacted Files Coverage Δ
src/transformers/__init__.py 99.25% <ø> (ø)
src/transformers/modeling_auto.py 63.95% <ø> (-14.54%) ⬇️
src/transformers/modeling_encoder_decoder.py 91.02% <ø> (ø)
src/transformers/modeling_bert.py 88.42% <50.00%> (ø)
src/transformers/modeling_tf_bert.py 96.22% <50.00%> (-0.36%) ⬇️
src/transformers/modeling_roberta.py 95.98% <97.36%> (+0.20%) ⬆️
src/transformers/optimization.py 28.94% <0.00%> (-67.11%) ⬇️
src/transformers/modeling_tf_flaubert.py 24.53% <0.00%> (-63.20%) ⬇️
src/transformers/pipelines.py 26.98% <0.00%> (-52.81%) ⬇️
src/transformers/optimization_tf.py 33.33% <0.00%> (-24.33%) ⬇️
... and 18 more

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@sgugger (Collaborator) left a comment


Great to see the encoder/decoder framework expanded! Thanks for all the work!

Review thread on src/transformers/modeling_roberta.py (outdated, resolved):
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs
sgugger (Collaborator):

Are you using those kwargs? If so, change the docstring, since there are no legacy arguments here.

patrickvonplaten (Contributor, Author):

Removed them from the corresponding BERT model as well.

Review thread on tests/test_modeling_bert.py (resolved).
@sshleifer (Contributor) left a comment


LGTM. Have people gotten good ROUGE scores with the "glue together two pretrained models and finetune for summarization" approach?

@patrickvonplaten (Contributor, Author)

> LGTM. Have people gotten good ROUGE scores with the "glue together two pretrained models and finetune for summarization" approach?

Hmm, I think it's very new, so I'm not sure many people have tried out the framework yet. @patil-suraj - do you know if people work a lot with EncoderDecoder by chance?

@patil-suraj (Contributor)

> do you know if people work a lot with EncoderDecoder by chance?

Seems like it; I've seen quite a few issues and questions (on the forum as well) regarding EncoderDecoder, but no one has reported good results yet.

@ibeltagy (Contributor) commented Aug 11, 2020

Looks great. Thanks, @patrickvonplaten.

> LGTM. Have people gotten good ROUGE scores with the "glue together two pretrained models and finetune for summarization" approach?

@sshleifer, I was thinking about the same thing. My guess is that the numbers won't be great because the cross-attention is randomly initialized?
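
As a rough illustration of which weights start from scratch in such a composed model (not part of this PR's diff; it assumes the crossattention module naming used in the BERT/RoBERTa implementations), one could inspect the parameter names:

from transformers import EncoderDecoderModel

# Tie two pretrained encoder-only checkpoints into a seq2seq model.
model = EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")

# The decoder's cross-attention layers have no pretrained counterpart,
# so they are randomly initialized.
new_params = [name for name, _ in model.named_parameters() if "crossattention" in name]
print(len(new_params), "cross-attention parameter tensors, e.g.", new_params[0])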

@LysandreJik (Member) left a comment


Great, LGTM!

@add_start_docstrings(
"""RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
)
class RobertaForCausalLM(BertPreTrainedModel):
LysandreJik (Member):

Wouldn't it be more coherent to have it as RobertaLMHeadModel?

LysandreJik (Member):

(But I do prefer RobertaForCausalLM)

@sgugger (Collaborator) commented Aug 12, 2020:

It's the same naming as for BERT.

patrickvonplaten (Contributor, Author):

Following internal discussion, we'll leave the name as is, since it is more precise; BertLMHeadModel should change in the future.
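
For reference, a minimal standalone usage sketch of the new RobertaForCausalLM (not from this PR; it assumes the usual pattern of setting is_decoder=True on the config so the model uses a causal attention mask):

from transformers import RobertaConfig, RobertaForCausalLM, RobertaTokenizer

config = RobertaConfig.from_pretrained("roberta-base")
config.is_decoder = True  # use a causal attention mask so the model can act as a decoder

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, return_dict=True)
prediction_logits = outputs.logits  # next-token prediction scores of shape (batch, seq_len, vocab)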

Review thread on tests/test_modeling_bert.py (resolved).

Comment on lines -979 to +983:

-    labels=None,
     encoder_hidden_states=None,
     encoder_attention_mask=None,
+    labels=None,
LysandreJik (Member):

good catch

Comment on lines 236 to 237:

    kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
        Used to hide legacy arguments that have been deprecated.
LysandreJik (Member):

No need for this, I think.

@patrickvonplaten patrickvonplaten merged commit 0735def into huggingface:master Aug 12, 2020
@patrickvonplaten (Contributor, Author) commented Aug 17, 2020

> Looks great. Thanks, @patrickvonplaten.
>
> LGTM. Have people gotten good ROUGE scores with the "glue together two pretrained models and finetune for summarization" approach?
>
> @sshleifer, I was thinking about the same thing. My guess is that the numbers won't be great because the cross-attention is randomly initialized?

Btw, this paper does some great analysis on reusing checkpoints for Seq2Seq models: https://arxiv.org/pdf/1907.12461.pdf

fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020