-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split LMBert model in two #4874
Changes from 3 commits
1765e83
0604146
904004a
181e0fa
7524001
a78b5b7
f37914b
020d8e5
56a698f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -852,9 +852,10 @@ def forward( | |
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions) | ||
|
||
|
||
# TODO: Split with a different BertWithLMHead to get rid of `lm_labels` here and in encoder_decoder. | ||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) | ||
class BertForMaskedLM(BertPreTrainedModel): | ||
@add_start_docstrings( | ||
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING | ||
) | ||
class BertLMHeadModel(BertPreTrainedModel): | ||
def __init__(self, config): | ||
super().__init__(config) | ||
|
||
|
@@ -878,17 +879,129 @@ def forward( | |
labels=None, | ||
encoder_hidden_states=None, | ||
encoder_attention_mask=None, | ||
lm_labels=None, | ||
**kwargs | ||
): | ||
r""" | ||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): | ||
Labels for computing the masked language modeling loss. | ||
Labels for computing the left-to-right language modeling loss (next word prediction). | ||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) | ||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels | ||
in ``[0, ..., config.vocab_size]`` | ||
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): | ||
Labels for computing the left-to-right language modeling loss (next word prediction). | ||
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): | ||
Used to hide legacy arguments that have been deprecated. | ||
|
||
Returns: | ||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: | ||
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): | ||
Next token prediction loss. | ||
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) | ||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | ||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): | ||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) | ||
of shape :obj:`(batch_size, sequence_length, hidden_size)`. | ||
|
||
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | ||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): | ||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape | ||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`. | ||
|
||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention | ||
heads. | ||
|
||
Examples:: | ||
|
||
from transformers import BertTokenizer, BertLMHeadModel | ||
import torch | ||
|
||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') | ||
model = BertLMHeadModel.from_pretrained('bert-base-uncased') | ||
|
||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 | ||
outputs = model(input_ids, labels=input_ids) | ||
|
||
loss, prediction_scores = outputs[:2] | ||
|
||
""" | ||
|
||
outputs = self.bert( | ||
input_ids, | ||
attention_mask=attention_mask, | ||
token_type_ids=token_type_ids, | ||
position_ids=position_ids, | ||
head_mask=head_mask, | ||
inputs_embeds=inputs_embeds, | ||
encoder_hidden_states=encoder_hidden_states, | ||
encoder_attention_mask=encoder_attention_mask, | ||
) | ||
|
||
sequence_output = outputs[0] | ||
prediction_scores = self.cls(sequence_output) | ||
|
||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here | ||
|
||
if labels is not None: | ||
# we are doing next-token prediction; shift prediction scores and input ids by one | ||
prediction_scores = prediction_scores[:, :-1, :].contiguous() | ||
labels = labels[:, 1:].contiguous() | ||
loss_fct = CrossEntropyLoss() | ||
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) | ||
outputs = (ltr_lm_loss,) + outputs | ||
|
||
return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions) | ||
|
||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): | ||
input_shape = input_ids.shape | ||
effective_batch_size = input_shape[0] | ||
|
||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly | ||
if attention_mask is None: | ||
attention_mask = input_ids.new_ones(input_shape) | ||
|
||
# if model is does not use a causal mask then add a dummy token | ||
if self.config.is_decoder is False: | ||
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation" | ||
attention_mask = torch.cat( | ||
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1 | ||
) | ||
|
||
dummy_token = torch.full( | ||
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device | ||
) | ||
input_ids = torch.cat([input_ids, dummy_token], dim=1) | ||
|
||
return {"input_ids": input_ids, "attention_mask": attention_mask} | ||
sgugger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) | ||
class BertForMaskedLM(BertPreTrainedModel): | ||
def __init__(self, config): | ||
super().__init__(config) | ||
|
||
self.bert = BertModel(config) | ||
self.cls = BertOnlyMLMHead(config) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we could also do a check here that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't raise an exception here before the automodel in encoder_decode picks the right LM model. (Or the tests will fail.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, we could do this in a separate PR |
||
|
||
self.init_weights() | ||
|
||
def get_output_embeddings(self): | ||
return self.cls.predictions.decoder | ||
|
||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) | ||
def forward( | ||
self, | ||
input_ids=None, | ||
attention_mask=None, | ||
token_type_ids=None, | ||
position_ids=None, | ||
head_mask=None, | ||
inputs_embeds=None, | ||
labels=None, | ||
encoder_hidden_states=None, | ||
encoder_attention_mask=None, | ||
**kwargs | ||
): | ||
r""" | ||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): | ||
Labels for computing the masked language modeling loss. | ||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) | ||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels | ||
in ``[0, ..., config.vocab_size]`` | ||
|
@@ -899,8 +1012,6 @@ def forward( | |
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: | ||
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: | ||
Masked language modeling loss. | ||
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided): | ||
Next token prediction loss. | ||
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) | ||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | ||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): | ||
|
@@ -935,6 +1046,7 @@ def forward( | |
DeprecationWarning, | ||
) | ||
labels = kwargs.pop("masked_lm_labels") | ||
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task." | ||
sgugger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." | ||
|
||
outputs = self.bert( | ||
|
@@ -953,26 +1065,12 @@ def forward( | |
|
||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here | ||
|
||
# Although this may seem awkward, BertForMaskedLM supports two scenarios: | ||
# 1. If a tensor that contains the indices of masked labels is provided, | ||
# the cross-entropy is the MLM cross-entropy that measures the likelihood | ||
# of predictions for masked words. | ||
# 2. If `lm_labels` is provided we are in a causal scenario where we | ||
# try to predict the next token for each input in the decoder. | ||
if labels is not None: | ||
loss_fct = CrossEntropyLoss() # -100 index = padding token | ||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) | ||
outputs = (masked_lm_loss,) + outputs | ||
|
||
if lm_labels is not None: | ||
# we are doing next-token prediction; shift prediction scores and input ids by one | ||
prediction_scores = prediction_scores[:, :-1, :].contiguous() | ||
lm_labels = lm_labels[:, 1:].contiguous() | ||
loss_fct = CrossEntropyLoss() | ||
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)) | ||
outputs = (ltr_lm_loss,) + outputs | ||
|
||
return outputs # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions) | ||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) | ||
|
||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): | ||
sgugger marked this conversation as resolved.
Show resolved
Hide resolved
|
||
input_shape = input_ids.shape | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -70,7 +70,6 @@ def prepare_config_and_inputs_bert(self): | |
"decoder_token_labels": decoder_token_labels, | ||
"decoder_choice_labels": decoder_choice_labels, | ||
"encoder_hidden_states": encoder_hidden_states, | ||
"lm_labels": decoder_token_labels, | ||
"labels": decoder_token_labels, | ||
} | ||
|
||
|
@@ -288,38 +287,6 @@ def create_and_check_bert_encoder_decoder_model_labels( | |
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) | ||
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,))) | ||
|
||
def create_and_check_bert_encoder_decoder_model_lm_labels( | ||
self, | ||
config, | ||
input_ids, | ||
attention_mask, | ||
encoder_hidden_states, | ||
decoder_config, | ||
decoder_input_ids, | ||
decoder_attention_mask, | ||
lm_labels, | ||
**kwargs | ||
): | ||
encoder_model = BertModel(config) | ||
decoder_model = BertForMaskedLM(decoder_config) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we change to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) | ||
enc_dec_model.to(torch_device) | ||
outputs_encoder_decoder = enc_dec_model( | ||
input_ids=input_ids, | ||
decoder_input_ids=decoder_input_ids, | ||
attention_mask=attention_mask, | ||
decoder_attention_mask=decoder_attention_mask, | ||
lm_labels=lm_labels, | ||
) | ||
|
||
lm_loss = outputs_encoder_decoder[0] | ||
self.check_loss_output(lm_loss) | ||
# check that backprop works | ||
lm_loss.backward() | ||
|
||
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))) | ||
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,))) | ||
|
||
def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): | ||
encoder_model = BertModel(config) | ||
decoder_model = BertForMaskedLM(decoder_config) | ||
|
@@ -356,10 +323,6 @@ def test_bert_encoder_decoder_model_labels(self): | |
input_ids_dict = self.prepare_config_and_inputs_bert() | ||
self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict) | ||
|
||
def test_bert_encoder_decoder_model_lm_labels(self): | ||
input_ids_dict = self.prepare_config_and_inputs_bert() | ||
self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict) | ||
|
||
def test_bert_encoder_decoder_model_generate(self): | ||
input_ids_dict = self.prepare_config_and_inputs_bert() | ||
self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the init should have an assert statement now to make sure
config.is_decoder=True
or even setconfig.is_decoder=True
with a logging info that it does so. UsingBertLMHeadModel
without a causal mask does not make much sense IMO.Also, currently
BertLayer
always initializes the cross-attention layer ifconfig.is_decoder=True
, but doesn't use it if noencoder_hidden_states
are passed to forward => so there is not really a problem atm when using the model on its own (not in an encoder-decoder framework). BUT, this should be corrected with an additional checkself.config.is_encoder_decoder
before adding the cross-attention layers.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added an assert.