
Split LMBert model in two #4874

Merged: 9 commits merged into huggingface:master from split_bert_lm on Jun 10, 2020
Conversation

@sgugger (Collaborator) commented Jun 9, 2020

As discussed in #4711, the BertForMaskedLM model should be split in two to avoid having two different label arguments: one model for causal LM, one for masked LM. This PR follows up on that and does the split.

It introduces a new BertLMHeadModel (also added to the __init__ and the docs) with a test. As discussed, there is no deprecation warning if someone tries to use lm_labels in BertForMaskedLM (since it was experimental), only an error message telling the user to use BertLMHeadModel.

I did not add BertLMHeadModel in the automodel logic since we probably want users to use causal models for this? Let me know if I should add it even if it's not the best model for that task.

I also removed lm_labels in the EncoderDecoderModel since it was only there to support that argument in BertForMaskedLM (which then removes the corresponding test).
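For illustration, here is a rough sketch of how usage looks after the split; the checkpoint name, example sentence, and exact call signatures are assumptions for the sketch rather than code from this PR:

```python
from transformers import BertConfig, BertForMaskedLM, BertLMHeadModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")

# Masked LM: public behaviour is unchanged, with a single `labels` argument.
masked_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
masked_outputs = masked_model(**inputs, labels=inputs["input_ids"])

# Causal LM: the new head model; it is meant to be used with a decoder
# config so that a causal attention mask is applied.
decoder_config = BertConfig.from_pretrained("bert-base-uncased", is_decoder=True)
causal_model = BertLMHeadModel.from_pretrained("bert-base-uncased", config=decoder_config)
causal_outputs = causal_model(**inputs, labels=inputs["input_ids"])
```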

@clmnt requested review from julien-c and thomwolf on June 9, 2020 14:03
codecov bot commented Jun 9, 2020

Codecov Report

Merging #4874 into master will decrease coverage by 0.68%.
The diff coverage is 88.88%.


@@            Coverage Diff             @@
##           master    #4874      +/-   ##
==========================================
- Coverage   77.11%   76.43%   -0.69%     
==========================================
  Files         128      128              
  Lines       21651    21671      +20     
==========================================
- Hits        16697    16564     -133     
- Misses       4954     5107     +153     
| Impacted Files | Coverage Δ |
|---|---|
| src/transformers/modeling_encoder_decoder.py | 92.20% <ø> (ø) |
| src/transformers/modeling_bert.py | 88.17% <88.88%> (-0.12%) ⬇️ |
| src/transformers/modeling_tf_pytorch_utils.py | 8.72% <0.00%> (-81.21%) ⬇️ |
| src/transformers/modeling_ctrl.py | 96.56% <0.00%> (-2.58%) ⬇️ |
| src/transformers/modeling_xlnet.py | 76.31% <0.00%> (-2.31%) ⬇️ |
| src/transformers/modeling_openai.py | 79.51% <0.00%> (-1.39%) ⬇️ |
| src/transformers/trainer.py | 38.26% <0.00%> (-1.18%) ⬇️ |
| src/transformers/benchmark/benchmark_utils.py | 72.80% <0.00%> (-0.30%) ⬇️ |
| src/transformers/file_utils.py | 73.79% <0.00%> (+0.40%) ⬆️ |

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@LysandreJik (Member) left a comment:

I think this is cool, it simplifies the BertForMaskedLM model and makes the API more coherent.

However, this is a breaking change: the model now behaves differently, with part of its functionality removed.

I'm not sure this breaking change is worth it.

Review thread on src/transformers/modeling_bert.py (resolved)
@sgugger (Collaborator, Author) commented Jun 9, 2020

It's possible to do it in a non-breaking way with a deprecation warning if a non-None lm_labels is passed to BertForMaskedLM. I was following the discussion in #4711, which implied it was okay to have a breaking change for this.
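For reference, the non-breaking variant would look roughly like the following; warn_on_lm_labels is a hypothetical helper for illustration, not code from this PR (which raises an error instead):

```python
import warnings

# Hypothetical sketch of the non-breaking alternative: warn instead of
# raising when a caller still passes `lm_labels` to BertForMaskedLM.
def warn_on_lm_labels(lm_labels=None):
    if lm_labels is not None:
        warnings.warn(
            "`lm_labels` is deprecated in BertForMaskedLM; use BertLMHeadModel "
            "for causal language modeling instead.",
            FutureWarning,
        )

warn_on_lm_labels(lm_labels=[0, 1, 2])  # emits a FutureWarning
```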

@patrickvonplaten (Contributor) commented Jun 9, 2020

I'm fine with this PR. IMO, BertForMaskedLM was never really used for causal language modeling before, except when using Bert in an encoder-decoder setting, and the encoder-decoder code is not really released yet. Also, since we keep the same names for the submodules self.bert and self.cls, there won't be any errors or inconsistencies when loading pre-trained weights into the Bert2Bert encoder-decoder.

In my opinion, this change is necessary to have a clean separation between masked lm and causal lm (Reformer and Longformer will eventually run into the same issue).

The heavily used BertForMaskedLM, i.e. the normal masked-encoder Bert model, does not change at all except for the removal of lm_labels, so that's good in terms of backward compatibility.

One thing which is problematic though is that the MODEL_WITH_LM_HEAD_MAPPING contains a mixture of causal models and masked encoder models at the moment:

MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(

Now since Bert has both a causal model and a masked encoder model we need two mappings.

I would suggest creating 2 new mappings, MODEL_FOR_MASKED_LM_MAPPING and MODEL_FOR_CAUSAL_LM_MAPPING, and two new AutoModels, AutoModelForMaskedLM and AutoModelForCausalLM, and for now keeping AutoModelWithLMHead as it is and adding a deprecation warning to it.

We can add BertLMHeadModel to MODEL_FOR_CAUSAL_LM_MAPPING and switch to AutoModelForCausalLM in the encoder-decoder model. Also pinging @thomwolf and @julien-c here.
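Concretely, the proposed mappings might look roughly like this (a sketch showing only the Bert entries; the real mappings would enumerate every supported architecture):

```python
from collections import OrderedDict

from transformers import BertConfig, BertForMaskedLM, BertLMHeadModel

# Sketch of splitting MODEL_WITH_LM_HEAD_MAPPING into a masked-LM mapping
# and a causal-LM mapping, with Bert appearing in both.
MODEL_FOR_MASKED_LM_MAPPING = OrderedDict([(BertConfig, BertForMaskedLM)])
MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict([(BertConfig, BertLMHeadModel)])
```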

@add_start_docstrings(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
Comment (Contributor):

I think the init should now have an assert statement to make sure config.is_decoder=True, or even set config.is_decoder=True with a logging message saying that it does so (see the sketch after this comment). Using BertLMHeadModel without a causal mask does not make much sense IMO.

Also, currently BertLayer always initializes the cross-attention layer if config.is_decoder=True, but doesn't use it if no encoder_hidden_states are passed to forward => so there is not really a problem atm when using the model on its own (not in an encoder-decoder framework). BUT, this should be corrected with an additional check on self.config.is_encoder_decoder before adding the cross-attention layers.
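A minimal sketch of the suggested guard, assuming a helper named ensure_decoder_config (hypothetical; the PR ends up adding a plain assert in __init__):

```python
import logging

from transformers import BertConfig

logger = logging.getLogger(__name__)

# Hypothetical helper: either assert the config is a decoder config, or
# force `is_decoder=True` and log that it was changed.
def ensure_decoder_config(config: BertConfig) -> BertConfig:
    if not config.is_decoder:
        logger.info(
            "Setting `config.is_decoder=True`: BertLMHeadModel needs a causal "
            "attention mask to make sense for causal LM fine-tuning."
        )
        config.is_decoder = True
    return config

config = ensure_decoder_config(BertConfig())
assert config.is_decoder
```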

Reply (Collaborator, Author):

Added an assert.

super().__init__(config)

self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
@patrickvonplaten (Contributor) commented Jun 9, 2020:

I think we could also do a check here that self.config.is_decoder == False.

@sgugger (Collaborator, Author) commented Jun 9, 2020:

We can't raise an exception here until the automodel logic in the encoder-decoder code picks the right LM model (or the tests will fail).

Reply (Contributor):

OK, we could do this in a separate PR.

**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
@patrickvonplaten (Contributor) commented Jun 9, 2020:

Can we change to BertLMHeadModel in all the other tests here? I'm fine with this one being deleted, since I don't really see an application where one would train an encoder-decoder model on the masked LM objective.

Reply (Collaborator, Author):

Done.
Note that the tests using automodel pick the wrong decoder for now.
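For context, the updated test setup would presumably build the decoder from the new class, roughly like this (a sketch with placeholder configs, not the actual test code):

```python
from transformers import BertConfig, BertLMHeadModel, BertModel

# Sketch of the encoder-decoder test setup after the change: the decoder
# is the causal BertLMHeadModel, built from a config with `is_decoder=True`.
config = BertConfig()
decoder_config = BertConfig(is_decoder=True)

encoder_model = BertModel(config)
decoder_model = BertLMHeadModel(decoder_config)
```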

@sgugger (Collaborator, Author) commented Jun 9, 2020

I agree with @patrickvonplaten on the need to split AutoModelWithLMHead in two. Note that if the name AutoModelForCausalLM is picked, we should then rename (with a deprecation first of course) all ModeltypeLMHeadModel to ModeltypeForCausalLM for consistency (and clarity since just saying it has an LM head doesn't tell us if it's intended to be masked or causal).

@LysandreJik (Member) commented Jun 9, 2020

I agree that having two additional AutoXXX classes for the distinction between masked/causal would be nice. We should, however, keep the AutoModelWithLMHead class available for backwards compatibility.

I don't agree with renaming all causal models with language modeling heads to XXXForCausalLM. It would be more consistent, but it's an aesthetic change that comes with a very big breaking change. Even adding aliases to keep backwards compatibility would create a large overhead for the user, in my opinion, as all those classes would exist twice when importing from the library.

@sgugger (Collaborator, Author) commented Jun 9, 2020

In that case I would advocate keeping AutoModelWithLMHead for causal language models and only adding an AutoModelForMaskedLM. Consistency is cosmetic, I agree, but it also helps avoid confusing beginners.

@patrickvonplaten (Contributor) commented:

  1. For now, I think the best solution would be to keep AutoModelWithLMHead as it is and add two new AutoXXX classes. The EncoderDecoderModel would be the first model to use AutoModelForCausalLM in its code.

     AutoModelWithLMHead is heavily used for all kinds of masked Bert encoder models, so if we create an AutoModelForMaskedLM and move BertForMaskedLM there, we would have a lot of breaking changes. I think we could add a deprecation warning to AutoModelWithLMHead though (see the sketch after this list).

  2. I'm a bit indifferent to renaming all the other model classes. While I'm also a big fan of consistency, I agree with @LysandreJik that it's a big user-facing API change and not really urgent atm.
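The deprecation mentioned in point 1 could be as simple as a warning in from_pretrained; a hypothetical sketch (PatchedAutoModelWithLMHead is an illustrative name, not a real class):

```python
import warnings

from transformers import AutoModelWithLMHead

# Hypothetical sketch: keep AutoModelWithLMHead working, but warn users to
# move to the two new auto classes.
class PatchedAutoModelWithLMHead(AutoModelWithLMHead):
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        warnings.warn(
            "`AutoModelWithLMHead` is deprecated; use `AutoModelForMaskedLM` "
            "or `AutoModelForCausalLM` instead.",
            FutureWarning,
        )
        return super().from_pretrained(*args, **kwargs)
```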

@julien-c (Member) commented Jun 9, 2020

In the short term, I would advocate only exposing the classical "masked-lm" flavour of BERT through AutoModelWithLMHead (as is done in this PR), and not even documenting/adding BertLMHeadModel to the __init__, as it's only used as a building block to other models.

In the longer term, I'd be ok with creating AutoModelFor{Masked,Causal}LM (name TBD for the second one) and not even creating a deprecation for AutoModelWithLMHead, forcing users to explicitly choose one or the other. This would need to be a major release though.

@LysandreJik (Member) commented Jun 9, 2020

@julien-c as long as we do a major release for the AutoModel renaming, I'm all for this!

@sshleifer (Contributor) left a comment:

read this, very nice!

@patrickvonplaten (Contributor) commented:

> In the short term, I would advocate only exposing the classical "masked-lm" flavour of BERT through AutoModelWithLMHead (as is done in this PR), and not even documenting/adding BertLMHeadModel to the __init__, as it's only used as a building block to other models.
>
> In the longer term, I'd be ok with creating AutoModelFor{Masked,Causal}LM (name TBD for the second one) and not even creating a deprecation for AutoModelWithLMHead, forcing users to explicitly choose one or the other. This would need to be a major release though.

For the encoder-decoder models, I think we need BertLMHeadModel in the __init__ and we would also need an AutoModelForCausalLM. Here:

from .modeling_auto import AutoModelWithLMHead

we need to instantiate a Bert model with a causal LM head (BertLMHeadModel) instead.
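A rough sketch of that direction, assuming the proposed AutoModelForCausalLM exists and using bert-base-uncased purely as an example checkpoint:

```python
from transformers import AutoModel, AutoModelForCausalLM

# Sketch only: the encoder-decoder code path would pick the causal-LM auto
# class for the decoder instead of the generic AutoModelWithLMHead.
# AutoModelForCausalLM is only a proposal at this point in the discussion.
encoder = AutoModel.from_pretrained("bert-base-uncased")
decoder = AutoModelForCausalLM.from_pretrained("bert-base-uncased", is_decoder=True)
```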

@thomwolf (Member) commented:

I'm fine either way, I think you guys got all the important issues (backward compatibility versus cleanly building the future). I like what @patrickvonplaten and @julien-c are proposing.

@sgugger (Collaborator, Author) commented Jun 10, 2020

Fixed conflicts and followed @julien-c's advice. @LysandreJik or @patrickvonplaten, could you do one final review just to make sure everything is fine to merge?

@LysandreJik (Member) left a comment:

Let's not forget to write an explanation of why that breaking change was necessary in the release notes.

@sgugger merged commit 1e2631d into huggingface:master on Jun 10, 2020
@sgugger deleted the split_bert_lm branch on June 10, 2020 22:26
@patrickvonplaten (Contributor) commented:

This currently breaks the encoder-decoder framework's from_encoder_decoder_pretrained() method. I will do a PR tomorrow to fix it.
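For reference, the affected call is the usual construction path, roughly (illustrative checkpoint names, not taken from an actual report):

```python
from transformers import EncoderDecoderModel

# Illustrative only: the construction path reported as broken right after
# this PR was merged; a follow-up PR was planned to fix it.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
```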
