Add pretraining loss computation for TF Bert pretraining (#8470)
* Add pretraining loss computation for TF Bert pretraining

* Fix labels creation

* Fix T5 model

* restore T5 kwargs

* try a generic fix for pretraining models

* Apply style

* Override the prepare method for the BERT tests
jplu authored Nov 12, 2020
1 parent 91a67b7 commit 5d80539
Showing 3 changed files with 109 additions and 13 deletions.
84 changes: 79 additions & 5 deletions src/transformers/modeling_tf_bert.py
@@ -89,6 +89,38 @@
]


class TFBertPreTrainingLoss:
"""
Loss function suitable for BERT-like pre-training, that is, the task of pretraining a language model by combining
NSP + MLM.

.. note:: Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
"""

def compute_loss(self, labels, logits):
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE
)
# make sure only labels that are not equal to -100
# are taken into account for the loss computation
masked_lm_active_loss = tf.not_equal(tf.reshape(labels["labels"], (-1,)), -100)
masked_lm_reduced_logits = tf.boolean_mask(
tf.reshape(logits[0], (-1, shape_list(logits[0])[2])),
masked_lm_active_loss,
)
masked_lm_labels = tf.boolean_mask(tf.reshape(labels["labels"], (-1,)), masked_lm_active_loss)
next_sentence_active_loss = tf.not_equal(tf.reshape(labels["next_sentence_label"], (-1,)), -100)
next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits[1], (-1, 2)), next_sentence_active_loss)
next_sentence_label = tf.boolean_mask(
tf.reshape(labels["next_sentence_label"], (-1,)), mask=next_sentence_active_loss
)
masked_lm_loss = loss_fn(masked_lm_labels, masked_lm_reduced_logits)
next_sentence_loss = loss_fn(next_sentence_label, next_sentence_reduced_logits)
masked_lm_loss = tf.reshape(masked_lm_loss, (-1, shape_list(next_sentence_loss)[0]))
masked_lm_loss = tf.reduce_mean(masked_lm_loss, 0)

return masked_lm_loss + next_sentence_loss


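As an aside, a toy sketch of the -100 masking convention the class above relies on; the tensor values are made up for illustration:

import tensorflow as tf

labels = tf.constant([[-100, 5, -100, 7]])            # (batch=1, seq_len=4); two active positions
flat_labels = tf.reshape(labels, (-1,))
active = tf.not_equal(flat_labels, -100)              # [False, True, False, True]
kept_labels = tf.boolean_mask(flat_labels, active)    # [5, 7]; only these reach the cross-entropy
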
class TFBertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -688,6 +720,7 @@ class TFBertForPreTrainingOutput(ModelOutput):
heads.
"""

loss: Optional[tf.Tensor] = None
prediction_logits: tf.Tensor = None
seq_relationship_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@@ -814,7 +847,7 @@ def call(self, inputs, **kwargs):
""",
BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel):
class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

@@ -827,7 +860,21 @@ def get_output_embeddings(self):

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(self, inputs, **kwargs):
def call(
self,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
next_sentence_label=None,
training=False,
):
r"""
Return:
@@ -843,17 +890,44 @@ def call(self, inputs, **kwargs):
>>> prediction_scores, seq_relationship_scores = outputs[:2]
"""
return_dict = kwargs.get("return_dict")
return_dict = return_dict if return_dict is not None else self.bert.return_dict
outputs = self.bert(inputs, **kwargs)

if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
next_sentence_label = inputs[10] if len(inputs) > 10 else next_sentence_label
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)

outputs = self.bert(
inputs,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
prediction_scores = self.mlm(sequence_output, training=training)
seq_relationship_score = self.nsp(pooled_output)
total_loss = None

if labels is not None and next_sentence_label is not None:
d_labels = {"labels": labels}
d_labels["next_sentence_label"] = next_sentence_label
total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

if not return_dict:
return (prediction_scores, seq_relationship_score) + outputs[2:]

return TFBertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
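For context, a minimal usage sketch of the new loss path in TFBertForPreTraining; the checkpoint name and the choice to score every token are assumptions made for illustration:

import tensorflow as tf
from transformers import BertTokenizer, TFBertForPreTraining

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")   # assumed checkpoint
model = TFBertForPreTraining.from_pretrained("bert-base-uncased")

encoding = tokenizer("The cat sat on the mat.", return_tensors="tf")
# MLM labels: every position is scored here; positions to ignore would be set to -100.
labels = encoding["input_ids"]
# NSP label: one value per example, 0 = "sentence B follows sentence A".
next_sentence_label = tf.constant([0])

outputs = model(
    encoding["input_ids"],
    attention_mask=encoding["attention_mask"],
    token_type_ids=encoding["token_type_ids"],
    labels=labels,
    next_sentence_label=next_sentence_label,
    return_dict=True,
)
print(outputs.loss)  # combined MLM + NSP loss
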
11 changes: 11 additions & 0 deletions tests/test_modeling_tf_bert.py
@@ -26,6 +26,7 @@
if is_tf_available():
import tensorflow as tf

from transformers import TF_MODEL_FOR_PRETRAINING_MAPPING
from transformers.modeling_tf_bert import (
TFBertForMaskedLM,
TFBertForMultipleChoice,
@@ -274,6 +275,16 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
else ()
)

# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)

return inputs_dict

def setUp(self):
self.model_tester = TFBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
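To make the override concrete, a rough sketch of the label tensors the test ends up with for TFBertForPreTraining; the batch_size and seq_length values are assumptions:

import tensorflow as tf

batch_size, seq_length = 13, 7  # assumed tester dimensions

# Added for models in TF_MODEL_FOR_PRETRAINING_MAPPING by the common _prepare_for_class:
labels = tf.zeros((batch_size, seq_length), dtype=tf.int32)
# Added by the BERT-specific override above:
next_sentence_label = tf.zeros(batch_size, dtype=tf.int32)
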
27 changes: 19 additions & 8 deletions tests/test_modeling_tf_common.py
@@ -36,6 +36,7 @@
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
TF_MODEL_FOR_PRETRAINING_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -102,6 +103,7 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> d
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
*TF_MODEL_FOR_PRETRAINING_MAPPING.values(),
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
]:
inputs_dict["labels"] = tf.zeros(
@@ -834,7 +836,9 @@ def test_loss_computation(self):
if getattr(model, "compute_loss", None):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
]
loss_size = tf.size(added_label)

if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
@@ -859,23 +863,30 @@

# Get keys that were added with the _prepare_for_class function
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.getfullargspec(model.call)[0]
signature = inspect.signature(model.call).parameters
signature_names = list(signature.keys())

# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {1: "input_ids"}
tuple_index_mapping = {0: "input_ids"}
for label_key in label_keys:
label_key_index = signature.index(label_key)
label_key_index = signature_names.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with their default values, update the values and convert to a tuple
list_input = []

for name in signature_names:
if name != "kwargs":
list_input.append(signature[name].default)

# Initialize a list with None, update the values and convert to a tuple
list_input = [None] * sorted_tuple_index_mapping[-1][0]
for index, value in sorted_tuple_index_mapping:
list_input[index - 1] = prepared_for_class[value]
list_input[index] = prepared_for_class[value]

tuple_input = tuple(list_input)

# Send to model
loss = model(tuple_input)[0]
loss = model(tuple_input[:-1])[0]

self.assertEqual(loss.shape, [loss_size])

def _generate_random_bad_tokens(self, num_bad_tokens, model):
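To summarize the updated tuple path in test_loss_computation, a condensed sketch; the helper name and standalone packaging are assumptions, since the test inlines this logic:

import inspect

def build_positional_inputs(model_call, prepared_for_class, label_keys):
    # Inspect the call signature so every parameter starts from its declared default.
    signature = inspect.signature(model_call).parameters
    signature_names = list(signature.keys())

    # Map tuple positions to the tensors that should override the defaults.
    tuple_index_mapping = {0: "input_ids"}
    for label_key in label_keys:
        tuple_index_mapping[signature_names.index(label_key)] = label_key

    # Seed each slot with its default, then drop the real tensors into place.
    list_input = [signature[name].default for name in signature_names if name != "kwargs"]
    for index, name in sorted(tuple_index_mapping.items()):
        list_input[index] = prepared_for_class[name]

    return tuple(list_input)

The test then calls the model with this tuple minus its last entry (the training flag) and checks the shape of the returned loss.
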
