
Commit

Split LMBert model in two (#4874)
* Split LMBert model in two

* Fix example

* Remove lm_labels

* Adapt tests, refactor prepare_for_generation

* Fix merge

* Hide BertLMHeadModel
sgugger authored Jun 10, 2020
1 parent f6da8b2 commit 1e2631d
Showing 4 changed files with 163 additions and 95 deletions.
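
In practice the split means left-to-right (causal) language modeling moves out of `BertForMaskedLM` and its deprecated `lm_labels` argument into the new `BertLMHeadModel`, which takes ordinary `labels`. A minimal sketch of the two classes after this commit, adapted from the docstring example below (the checkpoint name and inputs are illustrative):

import torch
from transformers import BertTokenizer, BertForMaskedLM, BertLMHeadModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1

# Causal LM: previously requested via BertForMaskedLM(..., lm_labels=...), now a separate class.
causal_model = BertLMHeadModel.from_pretrained('bert-base-uncased', is_decoder=True)
ltr_lm_loss, prediction_scores = causal_model(input_ids, labels=input_ids)[:2]

# Masked LM: interface unchanged, `labels` computes the MLM loss as before.
mlm_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
masked_lm_loss, mlm_scores = mlm_model(input_ids, labels=input_ids)[:2]
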
158 changes: 119 additions & 39 deletions src/transformers/modeling_bert.py
@@ -873,11 +873,13 @@ def forward(
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)


# TODO: Split with a different BertWithLMHead to get rid of `lm_labels` here and in encoder_decoder.
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
@add_start_docstrings(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."

self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)
@@ -899,18 +901,119 @@ def forward(
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
lm_labels=None,
output_attentions=None,
**kwargs
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Labels for computing the left-to-right language modeling loss (next word prediction).
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction).
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Next token prediction loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import BertTokenizer, BertLMHeadModel
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertLMHeadModel.from_pretrained('bert-base-uncased', is_decoder=True)
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
loss, prediction_scores = outputs[:2]
"""

outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)

sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)

outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here

if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (ltr_lm_loss,) + outputs

return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

return {"input_ids": input_ids, "attention_mask": attention_mask}


@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config)

self.init_weights()

def get_output_embeddings(self):
return self.cls.predictions.decoder

@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
**kwargs
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
@@ -921,8 +1024,6 @@ def forward(
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
Next token prediction loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
@@ -957,6 +1058,7 @@ def forward(
DeprecationWarning,
)
labels = kwargs.pop("masked_lm_labels")
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

outputs = self.bert(
@@ -976,46 +1078,24 @@ def forward(

outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here

# Although this may seem awkward, BertForMaskedLM supports two scenarios:
# 1. If a tensor that contains the indices of masked labels is provided,
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# of predictions for masked words.
# 2. If `lm_labels` is provided we are in a causal scenario where we
# try to predict the next token for each input in the decoder.
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
outputs = (masked_lm_loss,) + outputs

if lm_labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores = prediction_scores[:, :-1, :].contiguous()
lm_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
outputs = (ltr_lm_loss,) + outputs

return outputs # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)

def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# if model does not use a causal mask then add a dummy token
if self.config.is_decoder is False:
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat(
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
)

dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
# add a dummy token
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)

return {"input_ids": input_ids, "attention_mask": attention_mask}

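With the causal branch gone, `BertForMaskedLM.prepare_inputs_for_generation` now always appends a dummy [PAD] token that is masked out in the attention mask. A short sketch of what the updated helper returns, assuming the `bert-base-uncased` checkpoint (which defines a pad token); purely illustrative:

import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
prepared = model.prepare_inputs_for_generation(input_ids)

# One [PAD] token is appended to every sequence and zeroed out in the attention mask.
assert prepared["input_ids"].shape[1] == input_ids.shape[1] + 1
assert prepared["input_ids"][0, -1].item() == model.config.pad_token_id
assert prepared["attention_mask"][0, -1].item() == 0
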
7 changes: 0 additions & 7 deletions src/transformers/modeling_encoder_decoder.py
@@ -192,7 +192,6 @@ def forward(
decoder_head_mask=None,
decoder_inputs_embeds=None,
labels=None,
lm_labels=None,
**kwargs,
):

@@ -239,11 +238,6 @@ def forward(
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction) for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
@@ -293,7 +287,6 @@ def forward(
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
lm_labels=lm_labels,
labels=labels,
**kwargs_decoder,
)
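On the encoder-decoder side, the decoder's left-to-right loss is now requested through `labels` alone. A hedged sketch of the updated call; the checkpoint names are illustrative, and the tuple layout follows the decoder's return, so the decoder loss comes first when `labels` is given:

import torch
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
# `lm_labels` is gone; the decoder computes its next-token loss from `labels`.
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
loss = outputs[0]
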
41 changes: 36 additions & 5 deletions tests/test_modeling_bert.py
@@ -35,7 +35,7 @@
BertForTokenClassification,
BertForMultipleChoice,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel


class BertModelTester:
@@ -211,6 +211,33 @@ def create_and_check_bert_model_as_decoder(
)
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

def create_and_check_bert_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
self.parent.assertListEqual(
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
)
self.check_loss_output(result)

def create_and_check_bert_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
@@ -229,7 +256,7 @@ def create_and_check_bert_for_masked_lm(
)
self.check_loss_output(result)

def create_and_check_bert_model_for_masked_lm_as_decoder(
def create_and_check_bert_model_for_causal_lm_as_decoder(
self,
config,
input_ids,
@@ -241,7 +268,7 @@ def create_and_check_bert_model_for_masked_lm_as_decoder(
encoder_hidden_states,
encoder_attention_mask,
):
model = BertForMaskedLM(config=config)
model = BertLMHeadModel(config=config)
model.to(torch_device)
model.eval()
loss, prediction_scores = model(
@@ -461,13 +488,17 @@ def test_bert_model_as_decoder_with_default_input_mask(self):
encoder_attention_mask,
)

def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_for_causal_lm(*config_and_inputs)

def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)

def test_for_masked_lm_decoder(self):
def test_for_causal_lm_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs)
self.model_tester.create_and_check_bert_model_for_causal_lm_as_decoder(*config_and_inputs)

def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
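Outside the test harness, the new causal-LM check boils down to roughly the following standalone snippet; the tiny config values mirror `BertModelTester` and are illustrative only:

import torch
from transformers import BertConfig
from transformers.modeling_bert import BertLMHeadModel

config = BertConfig(
    vocab_size=99,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=37,
    is_decoder=True,  # required by the new assert in BertLMHeadModel.__init__
)
model = BertLMHeadModel(config)
model.eval()

batch_size, seq_length = 13, 7
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))
loss, prediction_scores = model(input_ids, labels=input_ids)[:2]
assert prediction_scores.shape == (batch_size, seq_length, config.vocab_size)
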
