Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add T5ForSequenceClassification #1

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 152 additions & 1 deletion src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.utils.checkpoint import checkpoint

from ...activations import ACT2FN
from ...file_utils import (
add_code_sample_docstrings,
DUMMY_INPUTS,
DUMMY_MASK,
add_start_docstrings,
Expand All @@ -39,6 +40,7 @@
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqSequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
Expand Down Expand Up @@ -1824,3 +1826,152 @@ def forward(
)

return encoder_outputs


class T5ClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""

def __init__(
self,
input_dim: int,
inner_dim: int,
num_classes: int,
pooler_dropout: float,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)

def forward(self, hidden_states: torch.Tensor):
hidden_states = self.dropout(hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = torch.tanh(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.out_proj(hidden_states)
return hidden_states


@add_start_docstrings(
"""
T5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
""",
T5_START_DOCSTRING,
)
class T5ForSequenceClassification(T5PreTrainedModel):
def __init__(self, config: T5Config, **kwargs):
super().__init__(config, **kwargs)
self.model = T5EncoderModel(config)
self.classification_head = T5ClassificationHead(
config.d_model,
config.d_model,
config.num_labels,
config.classifier_dropout,
)
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj)

@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Seq2SeqSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False

if input_ids is None and inputs_embeds is not None:
raise NotImplementedError(
f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
)

outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0] # last hidden state

eos_mask = input_ids.eq(self.config.eos_token_id)

if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
:, -1, :
]
logits = self.classification_head(sentence_representation)

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.config.num_labels == 1:
self.config.problem_type = "regression"
elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.config.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

return Seq2SeqSequenceClassifierOutput(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)