
Split LMBert model in two #4874

Merged: 9 commits, Jun 10, 2020
Changes from 3 commits
7 changes: 7 additions & 0 deletions docs/source/model_doc/bert.rst
@@ -73,6 +73,13 @@ BertForPreTraining
:members:


BertLMHeadModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertLMHeadModel
:members:


BertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -186,6 +186,7 @@
BertForMultipleChoice,
BertForTokenClassification,
BertForQuestionAnswering,
BertLMHeadModel,
load_tf_weights_in_bert,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
BertLayer,
146 changes: 122 additions & 24 deletions src/transformers/modeling_bert.py
@@ -852,9 +852,10 @@ def forward(
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)


# TODO: Split with a different BertWithLMHead to get rid of `lm_labels` here and in encoder_decoder.
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
@add_start_docstrings(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
Contributor:

I think the init should now have an assert statement to make sure config.is_decoder=True, or it could even set config.is_decoder=True itself with a logging info message that it does so. Using BertLMHeadModel without a causal mask does not make much sense IMO.

Also, BertLayer currently always initializes the cross-attention layer when config.is_decoder=True, but doesn't use it if no encoder_hidden_states are passed to forward, so there is not really a problem at the moment when using the model on its own (outside an encoder-decoder framework). However, this should be corrected with an additional check on self.config.is_encoder_decoder before adding the cross-attention layers.

Collaborator (Author):

Added an assert.
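
For reference, a minimal sketch of the kind of guard discussed in this thread, as it could look inside BertLMHeadModel.__init__ (a sketch only; the exact assert message in the merged diff may differ):

    def __init__(self, config):
        super().__init__(config)
        # Guard discussed above: the causal LM head only makes sense when BertModel
        # applies a causal attention mask, i.e. when config.is_decoder is True.
        assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, set `config.is_decoder=True`."

        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()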

def __init__(self, config):
super().__init__(config)

@@ -878,17 +879,129 @@ def forward(
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
lm_labels=None,
**kwargs
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Labels for computing the left-to-right language modeling loss (next word prediction).
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction).
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.

Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Next token prediction loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.

Examples::

from transformers import BertTokenizer, BertLMHeadModel
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertLMHeadModel.from_pretrained('bert-base-uncased')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)

loss, prediction_scores = outputs[:2]

"""

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)

sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)

outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here

if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (ltr_lm_loss,) + outputs

return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
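
As a standalone illustration of the shift-by-one loss computed above (a toy sketch on random tensors, not part of the diff):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 10
    prediction_scores = torch.randn(1, 5, vocab_size)   # (batch_size, seq_length, vocab_size)
    labels = torch.randint(0, vocab_size, (1, 5))       # for causal LM the targets are the input ids

    # Position t predicts token t+1: drop the last logit and the first label.
    shifted_logits = prediction_scores[:, :-1, :].contiguous()
    shifted_labels = labels[:, 1:].contiguous()
    ltr_lm_loss = CrossEntropyLoss()(shifted_logits.view(-1, vocab_size), shifted_labels.view(-1))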

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# if model does not use a causal mask then add a dummy token
if self.config.is_decoder is False:
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat(
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
)

dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)

return {"input_ids": input_ids, "attention_mask": attention_mask}


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
Contributor (@patrickvonplaten), Jun 9, 2020:

I think we could also do a check here that self.config.is_decoder == False.

Collaborator (Author, @sgugger), Jun 9, 2020:

We can't raise an exception here until the automodel in encoder_decoder picks the right LM model (or the tests will fail).

Contributor:

Ok, we could do this in a separate PR
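
For context, a minimal sketch of the check proposed in this thread (hypothetical, and per the discussion deferred to a separate PR):

    # Inside BertForMaskedLM.__init__ (sketch); `logger` is the module-level logger of modeling_bert.py
    if config.is_decoder:
        logger.warning(
            "If you want to use `BertForMaskedLM`, make sure `config.is_decoder=False` for "
            "bi-directional self-attention."
        )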


self.init_weights()

def get_output_embeddings(self):
return self.cls.predictions.decoder

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
**kwargs
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
@@ -899,8 +1012,6 @@ def forward(
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
Next token prediction loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
@@ -935,6 +1046,7 @@ def forward(
DeprecationWarning,
)
labels = kwargs.pop("masked_lm_labels")
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

outputs = self.bert(
@@ -953,26 +1065,12 @@ def forward(

outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here

# Although this may seem awkward, BertForMaskedLM supports two scenarios:
# 1. If a tensor that contains the indices of masked labels is provided,
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# of predictions for masked words.
# 2. If `lm_labels` is provided we are in a causal scenario where we
# try to predict the next token for each input in the decoder.
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs

if lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :].contiguous()
lm_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
outputs = (ltr_lm_loss,) + outputs

return outputs # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
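
After the split, BertForMaskedLM only computes the masked-LM loss; a brief usage sketch (assuming the bert-base-uncased checkpoint and the tuple return format of this version):

    from transformers import BertTokenizer, BertForMaskedLM

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")

    input_ids = tokenizer.encode("The capital of France is [MASK].", return_tensors="pt")
    masked_lm_loss, prediction_scores = model(input_ids, labels=input_ids)[:2]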

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
7 changes: 0 additions & 7 deletions src/transformers/modeling_encoder_decoder.py
@@ -192,7 +192,6 @@ def forward(
decoder_head_mask=None,
decoder_inputs_embeds=None,
labels=None,
lm_labels=None,
**kwargs,
):

@@ -239,11 +238,6 @@
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction) for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
@@ -293,7 +287,6 @@ def forward(
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
lm_labels=lm_labels,
labels=labels,
**kwargs_decoder,
)
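
With lm_labels gone, the encoder-decoder forward only passes labels down to the decoder; a short sketch of a call after this change (assuming a Bert2Bert model built with from_encoder_decoder_pretrained; see the test discussion below about which LM class the auto mapping currently picks as decoder):

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    input_ids = tokenizer.encode("This is a longer input sentence.", return_tensors="pt")
    decoder_input_ids = tokenizer.encode("A short target.", return_tensors="pt")
    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        labels=decoder_input_ids,  # the loss is computed inside the decoder's forward
    )
    loss = outputs[0]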
23 changes: 23 additions & 0 deletions tests/test_modeling_bert.py
@@ -34,6 +34,7 @@
BertForSequenceClassification,
BertForTokenClassification,
BertForMultipleChoice,
BertLMHeadModel,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST

@@ -211,6 +212,24 @@ def create_and_check_bert_model_as_decoder(
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

def create_and_check_bert_for_autoregressive_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)

def create_and_check_bert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -460,6 +479,10 @@ def test_bert_model_as_decoder_with_default_input_mask(self):
encoder_attention_mask,
)

def test_for_autoregressive_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_autoregressive_lm(*config_and_inputs)

def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
37 changes: 0 additions & 37 deletions tests/test_modeling_encoder_decoder.py
@@ -70,7 +70,6 @@ def prepare_config_and_inputs_bert(self):
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"lm_labels": decoder_token_labels,
"labels": decoder_token_labels,
}

@@ -288,38 +287,6 @@ def create_and_check_bert_encoder_decoder_model_labels(
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))

def create_and_check_bert_encoder_decoder_model_lm_labels(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
lm_labels,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
Contributor (@patrickvonplaten), Jun 9, 2020:

Can we change to BertLMHeadModel in all the other tests here?
I'm fine with this one being deleted, since I don't really see an application where one would train an encoder-decoder model on the masked LM objective.

Collaborator (Author):

Done.
Note that the tests using automodel pick the wrong decoder for now.

enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
lm_labels=lm_labels,
)

lm_loss = outputs_encoder_decoder[0]
self.check_loss_output(lm_loss)
# check that backprop works
lm_loss.backward()

self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
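
A brief sketch of what the corresponding check looks like once the decoder is switched to BertLMHeadModel, as suggested in the thread above (hypothetical helper, not the exact merged test code):

    def create_and_check_encoder_decoder_with_lm_head_decoder(
        self, config, input_ids, attention_mask, decoder_config, decoder_input_ids, decoder_attention_mask, labels, **kwargs
    ):
        encoder_model = BertModel(config)
        decoder_model = BertLMHeadModel(decoder_config)  # causal decoder instead of BertForMaskedLM
        enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
        enc_dec_model.to(torch_device)
        outputs_encoder_decoder = enc_dec_model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        self.check_loss_output(outputs_encoder_decoder[0])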

def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
@@ -356,10 +323,6 @@ def test_bert_encoder_decoder_model_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict)

def test_bert_encoder_decoder_model_lm_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict)

def test_bert_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict)