Added TF TransfoXL Sequence Classification (huggingface#9169)
* TF Transfoxl seq classification

* Update test_modeling_tf_transfo_xl.py

Added num_labels to config level

* TF Transfoxl seq classification

* Update test_modeling_tf_transfo_xl.py

Added num_labels to config level

* code refactor

* code refactor

* code refactor
spatil6 authored and guyrosin committed Jan 15, 2021
1 parent e7935f5 commit 7d7783b
Showing 6 changed files with 229 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -899,6 +899,7 @@
from .models.transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
TFTransfoXLForSequenceClassification,
TFTransfoXLLMHeadModel,
TFTransfoXLMainLayer,
TFTransfoXLModel,
7 changes: 6 additions & 1 deletion src/transformers/models/auto/modeling_tf_auto.py
@@ -131,7 +131,11 @@
TFRobertaModel,
)
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from ..transfo_xl.modeling_tf_transfo_xl import TFTransfoXLLMHeadModel, TFTransfoXLModel
from ..transfo_xl.modeling_tf_transfo_xl import (
TFTransfoXLForSequenceClassification,
TFTransfoXLLMHeadModel,
TFTransfoXLModel,
)
from ..xlm.modeling_tf_xlm import (
TFXLMForMultipleChoice,
TFXLMForQuestionAnsweringSimple,
@@ -342,6 +346,7 @@
(GPT2Config, TFGPT2ForSequenceClassification),
(MPNetConfig, TFMPNetForSequenceClassification),
(OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification),
(TransfoXLConfig, TFTransfoXLForSequenceClassification),
(CTRLConfig, TFCTRLForSequenceClassification),
]
)
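
With this entry registered in the TF sequence classification mapping, the auto class can resolve a Transfo-XL config to the new head. A minimal sketch (the config values are illustrative, not from a released checkpoint):

from transformers import TFAutoModelForSequenceClassification, TransfoXLConfig

# Build from a config: the mapping entry added above resolves TransfoXLConfig
# to TFTransfoXLForSequenceClassification.
config = TransfoXLConfig(num_labels=3)
model = TFAutoModelForSequenceClassification.from_config(config)
print(type(model).__name__)  # TFTransfoXLForSequenceClassification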
1 change: 1 addition & 0 deletions src/transformers/models/transfo_xl/__init__.py
@@ -36,6 +36,7 @@
from .modeling_tf_transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding,
TFTransfoXLForSequenceClassification,
TFTransfoXLLMHeadModel,
TFTransfoXLMainLayer,
TFTransfoXLModel,
189 changes: 188 additions & 1 deletion src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
@@ -28,7 +28,14 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, keras_serializable, shape_list
from ...modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...utils import logging
from .configuration_transfo_xl import TransfoXLConfig
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
@@ -717,6 +724,40 @@ class TFTransfoXLLMHeadModelOutput(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None


@dataclass
class TFTransfoXLSequenceClassifierOutputWithPast(ModelOutput):
"""
Base class for outputs of sentence classification models.

Args:
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see :obj:`mems`
input) to speed up sequential decoding. The token ids which have their past given to this model should not
be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
mems: List[tf.Tensor] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None


TRANSFO_XL_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
@@ -969,3 +1010,149 @@ def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
inputs["mems"] = past

return inputs


@add_start_docstrings(
"""
The Transfo-XL Model transformer with a sequence classification head on top (linear layer).
:class:`~transformers.TFTransfoXLForSequenceClassification` uses the last token in order to do the classification,
as other causal models (e.g. GPT-1, GPT-2) do.
Since it does classification on the last token, it needs to know the position of the last token. If a
:obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (takes
the last value in each row of the batch).
""",
TRANSFO_XL_START_DOCSTRING,
)
class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.score = tf.keras.layers.Dense(
config.num_labels,
kernel_initializer=get_initializer(config.init_range),
name="score",
use_bias=False,
)
self.transformer = TFTransfoXLMainLayer(config, name="transformer")

def get_output_embeddings(self):
return self.transformer.word_emb

@add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="transfo-xl-wt103",
output_type=TFTransfoXLSequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification loss. Indices should be in ``[0, ...,
config.num_labels - 1]``.
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
mems=mems,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)

transformer_outputs = self.transformer(
input_ids=inputs["input_ids"],
mems=inputs["mems"],
head_mask=inputs["head_mask"],
inputs_embeds=inputs["inputs_embeds"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)

hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
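        # Scores are computed for every position; the last real token's row is selected below.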
logits_shape = shape_list(logits)
in_logits = None
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if inputs["input_ids"] is not None:
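                # Index of the last non-padding token per row: count the non-pad
                # positions and subtract one (0-based); gathered via tf.map_fn below.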
sequence_lengths = (
tf.reduce_sum(
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
-1,
keepdims=False,
)
- 1
)

def get_seq_element(sequence_position, input_batch):
return tf.strided_slice(
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
)

result = tf.map_fn(
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
loss = None

if inputs["labels"] is not None:
if inputs["input_ids"] is not None:
batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
else:
batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."

if not tf.is_tensor(sequence_lengths):
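                # sequence_lengths is the Python int -1 here (no pad token defined, or
                # inputs_embeds were passed): fall back to the last position of every row.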
in_logits = logits[0:batch_size, sequence_lengths]

loss = self.compute_loss(
tf.reshape(inputs["labels"], [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])
)

pooled_logits = in_logits if in_logits is not None else logits

if not inputs["return_dict"]:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return TFTransfoXLSequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
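
For reference, a minimal usage sketch of the new head. The tiny config below mirrors the test setup rather than a real checkpoint (`transfo-xl-wt103` ships no classification weights), so all values are illustrative:

import tensorflow as tf
from transformers import TransfoXLConfig, TFTransfoXLForSequenceClassification

# Illustrative small config; pad_token_id lets the head locate the last
# non-padding token in each row.
config = TransfoXLConfig(
    vocab_size=100,
    cutoffs=[10, 50, 80],
    div_val=2,
    d_model=32,
    d_embed=32,
    n_head=4,
    d_head=8,
    d_inner=128,
    n_layer=2,
    num_labels=2,
    pad_token_id=99,
)
model = TFTransfoXLForSequenceClassification(config)

input_ids = tf.constant([[5, 6, 7, 99, 99]])  # one right-padded row
labels = tf.constant([1])

outputs = model(input_ids, labels=labels)
print(outputs.logits.shape)  # (1, 2): one score per label, read off the last real token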
9 changes: 9 additions & 0 deletions src/transformers/utils/dummy_tf_objects.py
@@ -1302,6 +1302,15 @@ def __init__(self, *args, **kwargs):
requires_tf(self)


class TFTransfoXLForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_tf(self)

@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)


class TFTransfoXLLMHeadModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
26 changes: 24 additions & 2 deletions tests/test_modeling_tf_transfo_xl.py
@@ -27,7 +27,12 @@
if is_tf_available():
import tensorflow as tf

from transformers import TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, TFTransfoXLLMHeadModel, TFTransfoXLModel
from transformers import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFTransfoXLForSequenceClassification,
TFTransfoXLLMHeadModel,
TFTransfoXLModel,
)


class TFTransfoXLModelTester:
@@ -55,6 +60,9 @@ def __init__(
self.scope = None
self.seed = 1
self.eos_token_id = 0
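        # classification-head additions: three labels, the last vocab id as padding,
        # and the init range used by the score layer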
self.num_labels = 3
self.pad_token_id = self.vocab_size - 1
self.init_range = 0.01

def prepare_config_and_inputs(self):
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -77,6 +85,9 @@ def prepare_config_and_inputs(self):
div_val=self.div_val,
n_layer=self.num_hidden_layers,
eos_token_id=self.eos_token_id,
pad_token_id=self.vocab_size - 1,
init_range=self.init_range,
num_labels=self.num_labels,
)

return (config, input_ids_1, input_ids_2, lm_labels)
@@ -131,6 +142,11 @@ def create_and_check_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
)

def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels):
model = TFTransfoXLForSequenceClassification(config)
result = model(input_ids_1)
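        # logits come from the last token of each example: (batch_size, num_labels)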
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids_1, input_ids_2, lm_labels) = config_and_inputs
@@ -141,7 +157,9 @@ def prepare_config_and_inputs_for_common(self):
@require_tf
class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):

all_model_classes = (TFTransfoXLModel, TFTransfoXLLMHeadModel) if is_tf_available() else ()
all_model_classes = (
(TFTransfoXLModel, TFTransfoXLLMHeadModel, TFTransfoXLForSequenceClassification) if is_tf_available() else ()
)
all_generative_model_classes = () if is_tf_available() else ()
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
test_resize_embeddings = False
@@ -163,6 +181,10 @@ def test_transfo_xl_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_transfo_xl_lm_head(*config_and_inputs)

def test_transfo_xl_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs)

def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

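To exercise just the new test locally, a pytest invocation along these lines should work (test path as added in this PR):

python -m pytest tests/test_modeling_tf_transfo_xl.py -k sequence_classification -v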
