
T5ForSequenceClassification #14097

Open
MetcalfeTom opened this issue Oct 21, 2021 · 27 comments

@MetcalfeTom

🚀 Feature request

T5 to classify sequences by using only the encoder of T5 and a ClassificationHead.

Motivation

This gives the benefits of fine-tuning a model with no maximum sequence length (useful for long-sequence tasks) without having to load the decoder weights into memory or treat classification as a generative task.

Your contribution

I already have working code for this and saw some requests for it in other forums (Slack, PyTorch, Hugging Face), so if it's a welcome addition I'd be happy to add it to the library.
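
To make the idea concrete, here is a minimal sketch of what such an encoder-only classifier could look like on top of the existing T5EncoderModel (the class name, pooling, and head below are illustrative only, not my actual implementation):

import torch
from torch import nn
from transformers import AutoTokenizer, T5EncoderModel

# Illustrative sketch: T5 encoder + a simple classification head (mean pooling over tokens).
class T5EncoderForSequenceClassification(nn.Module):
    def __init__(self, model_name="t5-small", num_labels=2):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.classification_head = nn.Linear(self.encoder.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state              # (batch, seq_len, d_model)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        mask = attention_mask.unsqueeze(-1).float()
        # Mean-pool over non-padding tokens as a simple sentence representation.
        pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return self.classification_head(pooled)                # (batch, num_labels)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5EncoderForSequenceClassification()
batch = tokenizer(["a very long document ..."], return_tensors="pt", padding=True)
logits = model(batch["input_ids"], batch["attention_mask"])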

@prajjwal1
Contributor

T5ForMultipleChoice would also be very helpful.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@minmaxme

This seems like a useful addition, especially considering the EncT5 paper.

@subhalingamd
Contributor

any update on this?

@LysandreJik LysandreJik reopened this Jan 21, 2022
@LysandreJik
Member

Maybe of interest to @NielsRogge

@stefan-it
Collaborator

Token classification would also be very interesting when I think of evaluations for the BigScience project.

@stefan-it
Collaborator

But w.r.t. sequence classification, shouldn't it be similar to the sequence classification that is used for the BART model, as seen here:

class BartForSequenceClassification(BartPretrainedModel):
    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = BartModel(config)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.classifier_dropout,
        )
        self.model._init_weights(self.classification_head.dense)
        self.model._init_weights(self.classification_head.out_proj)

    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        if input_ids is None and inputs_embeds is not None:
            raise NotImplementedError(
                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
            )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]  # last hidden state

        eos_mask = input_ids.eq(self.config.eos_token_id)

        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
            :, -1, :
        ]
        logits = self.classification_head(sentence_representation)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.config.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.config.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

🤔
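
For intuition, here is a toy illustration (not from the BART source) of what the eos-token pooling above does, with made-up shapes and 1 standing in for eos_token_id:

import torch

hidden_states = torch.randn(2, 5, 8)        # (batch, seq_len, d_model)
input_ids = torch.tensor([[5, 9, 6, 1, 1],  # 1 plays the role of eos_token_id here
                          [7, 1, 3, 4, 1]])
eos_mask = input_ids.eq(1)
# Every example must contain the same number of eos tokens for the view() below to work.
assert len(torch.unique_consecutive(eos_mask.sum(1))) == 1
# Keep only the hidden state of the last eos token in each sequence.
sentence_representation = hidden_states[eos_mask, :].view(
    hidden_states.size(0), -1, hidden_states.size(-1)
)[:, -1, :]
print(sentence_representation.shape)        # torch.Size([2, 8])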

@subhalingamd
Contributor

subhalingamd commented Jan 24, 2022

Hi. I have done this at subhalingamd@82db59d, but the list of uninitialized weights doesn't seem convincing at all. Here is the output for reference (for t5-small):

Some weights of the model checkpoint at t5-small were not used when initializing T5ForSequenceClassification: ['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'encoder.block.4.layer.1.layer_norm.weight', 'decoder.block.4.layer.0.layer_norm.weight', 'encoder.block.4.layer.1.DenseReluDense.wo.weight', 'decoder.block.4.layer.1.EncDecAttention.v.weight', 'encoder.block.0.layer.1.layer_norm.weight', 'encoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.4.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.o.weight', 'encoder.block.1.layer.1.DenseReluDense.wo.weight', 'decoder.block.2.layer.1.EncDecAttention.k.weight', 'encoder.block.2.layer.1.DenseReluDense.wo.weight', 'decoder.final_layer_norm.weight', 'decoder.block.2.layer.1.EncDecAttention.q.weight', 'decoder.block.2.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.0.SelfAttention.o.weight', 'encoder.block.0.layer.1.DenseReluDense.wi.weight', 'encoder.block.4.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.o.weight', 'encoder.block.1.layer.0.SelfAttention.v.weight', 'encoder.block.1.layer.0.SelfAttention.k.weight', 'decoder.block.3.layer.0.layer_norm.weight', 'encoder.block.1.layer.0.layer_norm.weight', 'encoder.block.4.layer.0.layer_norm.weight', 'decoder.block.5.layer.0.SelfAttention.v.weight', 'decoder.block.3.layer.1.EncDecAttention.v.weight', 'encoder.block.2.layer.0.layer_norm.weight', 'encoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.1.EncDecAttention.k.weight', 'encoder.block.3.layer.1.DenseReluDense.wo.weight', 'encoder.block.3.layer.1.layer_norm.weight', 'encoder.block.0.layer.0.layer_norm.weight', 'decoder.block.2.layer.1.layer_norm.weight', 'decoder.block.2.layer.2.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.v.weight', 'encoder.final_layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.q.weight', 'encoder.block.5.layer.1.layer_norm.weight', 'decoder.block.4.layer.0.SelfAttention.v.weight', 'encoder.block.2.layer.0.SelfAttention.k.weight', 'encoder.block.3.layer.0.layer_norm.weight', 'encoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.3.layer.1.EncDecAttention.k.weight', 'encoder.block.2.layer.0.SelfAttention.v.weight', 'encoder.block.4.layer.0.SelfAttention.q.weight', 'encoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.1.EncDecAttention.q.weight', 'decoder.block.2.layer.1.EncDecAttention.v.weight', 'encoder.block.2.layer.0.SelfAttention.o.weight', 'encoder.block.3.layer.0.SelfAttention.v.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'encoder.block.0.layer.0.SelfAttention.v.weight', 'encoder.block.5.layer.1.DenseReluDense.wo.weight', 'decoder.block.1.layer.0.SelfAttention.k.weight', 'encoder.block.0.layer.1.DenseReluDense.wo.weight', 'decoder.block.1.layer.1.layer_norm.weight', 'encoder.block.1.layer.0.SelfAttention.q.weight', 'encoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.5.layer.0.SelfAttention.k.weight', 'encoder.block.1.layer.1.layer_norm.weight', 'encoder.block.0.layer.0.SelfAttention.k.weight', 'encoder.block.2.layer.1.DenseReluDense.wi.weight', 'decoder.block.2.layer.2.DenseReluDense.wo.weight', 'encoder.block.5.layer.0.SelfAttention.o.weight', 'decoder.block.3.layer.0.SelfAttention.q.weight', 'decoder.block.3.layer.1.EncDecAttention.q.weight', 'encoder.block.3.layer.0.SelfAttention.q.weight', 
'decoder.block.2.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.0.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight', 'encoder.block.3.layer.0.SelfAttention.o.weight', 'decoder.block.0.layer.1.EncDecAttention.o.weight', 'decoder.block.1.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.2.DenseReluDense.wo.weight', 'shared.weight', 'encoder.block.5.layer.0.SelfAttention.v.weight', 'decoder.block.3.layer.0.SelfAttention.o.weight', 'decoder.block.4.layer.0.SelfAttention.q.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.3.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.1.EncDecAttention.k.weight', 'decoder.block.4.layer.1.EncDecAttention.o.weight', 'decoder.block.3.layer.1.EncDecAttention.o.weight', 'encoder.block.1.layer.1.DenseReluDense.wi.weight', 'encoder.block.4.layer.1.DenseReluDense.wi.weight', 'decoder.block.5.layer.2.DenseReluDense.wo.weight', 'encoder.block.4.layer.0.SelfAttention.v.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.2.DenseReluDense.wi.weight', 'decoder.block.0.layer.0.SelfAttention.o.weight', 'decoder.block.5.layer.0.layer_norm.weight', 'encoder.block.4.layer.0.SelfAttention.o.weight', 'decoder.block.3.layer.1.layer_norm.weight', 'decoder.block.3.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.2.layer_norm.weight', 'decoder.block.5.layer.2.layer_norm.weight', 'decoder.block.1.layer.1.EncDecAttention.v.weight', 'encoder.block.5.layer.1.DenseReluDense.wi.weight', 'encoder.block.5.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.5.layer.1.layer_norm.weight', 'decoder.block.5.layer.1.EncDecAttention.v.weight', 'encoder.block.3.layer.1.DenseReluDense.wi.weight', 'decoder.block.2.layer.0.SelfAttention.q.weight', 'decoder.block.4.layer.1.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.v.weight', 'decoder.block.4.layer.2.DenseReluDense.wo.weight', 'decoder.block.5.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.2.DenseReluDense.wi.weight', 'decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.2.layer.2.DenseReluDense.wi.weight', 'decoder.block.1.layer.0.SelfAttention.o.weight', 'decoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.0.layer.1.EncDecAttention.k.weight', 'encoder.block.2.layer.0.SelfAttention.q.weight', 'decoder.block.1.layer.1.EncDecAttention.q.weight', 'encoder.block.5.layer.0.layer_norm.weight', 'decoder.block.0.layer.1.EncDecAttention.q.weight', 'decoder.block.1.layer.0.layer_norm.weight', 'decoder.block.4.layer.0.SelfAttention.k.weight', 'encoder.block.2.layer.1.layer_norm.weight', 'decoder.block.0.layer.0.SelfAttention.k.weight', 'decoder.block.3.layer.2.layer_norm.weight', 'decoder.block.1.layer.2.DenseReluDense.wi.weight', 'decoder.block.3.layer.0.SelfAttention.v.weight', 'decoder.block.4.layer.2.layer_norm.weight', 'decoder.block.2.layer.0.layer_norm.weight']
- This IS expected if you are initializing T5ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing T5ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at t5-small and are newly initialized: ['model.encoder.block.5.layer.0.SelfAttention.k.weight', 'model.decoder.block.4.layer.1.EncDecAttention.k.weight', 'model.decoder.block.4.layer.0.SelfAttention.k.weight', 'model.decoder.block.4.layer.2.DenseReluDense.wi.weight', 'model.decoder.block.3.layer.1.EncDecAttention.v.weight', 'model.decoder.block.0.layer.1.EncDecAttention.v.weight', 'model.decoder.block.5.layer.0.SelfAttention.v.weight', 'model.encoder.block.4.layer.0.SelfAttention.v.weight', 'model.decoder.block.5.layer.0.SelfAttention.q.weight', 'model.decoder.block.2.layer.1.EncDecAttention.k.weight', 'model.encoder.block.1.layer.0.SelfAttention.q.weight', 'model.encoder.block.2.layer.1.DenseReluDense.wo.weight', 'model.encoder.block.3.layer.0.SelfAttention.q.weight', 'model.decoder.block.0.layer.1.EncDecAttention.k.weight', 'model.decoder.block.3.layer.2.DenseReluDense.wi.weight', 'model.encoder.block.4.layer.0.SelfAttention.k.weight', 'model.encoder.block.4.layer.0.layer_norm.weight', 'model.decoder.block.0.layer.0.layer_norm.weight', 'model.decoder.block.4.layer.1.EncDecAttention.v.weight', 'model.decoder.block.3.layer.1.EncDecAttention.k.weight', 'model.decoder.block.2.layer.0.SelfAttention.v.weight', 'model.encoder.block.0.layer.1.layer_norm.weight', 'model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'model.decoder.block.1.layer.0.SelfAttention.o.weight', 'model.decoder.block.1.layer.0.layer_norm.weight', 'model.encoder.block.4.layer.1.DenseReluDense.wo.weight', 'model.decoder.block.0.layer.0.SelfAttention.k.weight', 'model.decoder.block.0.layer.2.DenseReluDense.wo.weight', 'model.decoder.block.5.layer.1.EncDecAttention.o.weight', 'model.decoder.block.2.layer.0.SelfAttention.o.weight', 'model.decoder.block.3.layer.0.SelfAttention.v.weight', 'model.decoder.block.0.layer.1.EncDecAttention.o.weight', 'classification_head.out_proj.weight', 'model.decoder.block.2.layer.0.SelfAttention.q.weight', 'model.decoder.block.4.layer.0.SelfAttention.v.weight', 'model.encoder.block.3.layer.0.SelfAttention.v.weight', 'model.encoder.block.5.layer.0.SelfAttention.o.weight', 'model.decoder.block.2.layer.0.SelfAttention.k.weight', 'model.decoder.block.3.layer.1.EncDecAttention.o.weight', 'model.encoder.block.2.layer.0.SelfAttention.v.weight', 'model.decoder.block.1.layer.2.DenseReluDense.wi.weight', 'model.decoder.block.4.layer.1.layer_norm.weight', 'model.decoder.block.1.layer.1.EncDecAttention.q.weight', 'model.decoder.block.5.layer.1.EncDecAttention.q.weight', 'model.decoder.block.5.layer.2.DenseReluDense.wi.weight', 'model.encoder.embed_tokens.weight', 'model.encoder.block.1.layer.1.DenseReluDense.wo.weight', 'model.encoder.block.1.layer.1.DenseReluDense.wi.weight', 'model.encoder.block.3.layer.0.layer_norm.weight', 'model.encoder.block.0.layer.1.DenseReluDense.wo.weight', 'model.decoder.block.3.layer.2.layer_norm.weight', 'model.encoder.block.5.layer.1.layer_norm.weight', 'model.encoder.block.3.layer.1.DenseReluDense.wo.weight', 'model.encoder.block.4.layer.1.layer_norm.weight', 'model.decoder.block.3.layer.0.layer_norm.weight', 'model.encoder.block.0.layer.0.SelfAttention.k.weight', 'model.decoder.block.2.layer.0.layer_norm.weight', 'model.decoder.block.0.layer.1.EncDecAttention.q.weight', 'model.encoder.block.4.layer.0.SelfAttention.o.weight', 'model.decoder.block.2.layer.2.layer_norm.weight', 
'model.decoder.block.2.layer.2.DenseReluDense.wo.weight', 'model.decoder.block.0.layer.0.SelfAttention.q.weight', 'model.decoder.block.0.layer.0.SelfAttention.v.weight', 'model.encoder.final_layer_norm.weight', 'model.encoder.block.0.layer.0.layer_norm.weight', 'model.encoder.block.3.layer.0.SelfAttention.k.weight', 'model.encoder.block.0.layer.0.SelfAttention.q.weight', 'model.encoder.block.2.layer.1.DenseReluDense.wi.weight', 'model.decoder.final_layer_norm.weight', 'model.decoder.block.4.layer.2.DenseReluDense.wo.weight', 'model.decoder.block.3.layer.0.SelfAttention.q.weight', 'model.encoder.block.2.layer.0.SelfAttention.o.weight', 'model.decoder.block.3.layer.0.SelfAttention.k.weight', 'model.encoder.block.0.layer.0.SelfAttention.o.weight', 'model.encoder.block.4.layer.1.DenseReluDense.wi.weight', 'model.decoder.block.0.layer.2.layer_norm.weight', 'model.decoder.block.1.layer.1.layer_norm.weight', 'model.encoder.block.5.layer.1.DenseReluDense.wo.weight', 'model.encoder.block.1.layer.0.SelfAttention.v.weight', 'model.decoder.block.1.layer.0.SelfAttention.v.weight', 'model.encoder.block.1.layer.1.layer_norm.weight', 'classification_head.dense.bias', 'model.decoder.block.2.layer.1.layer_norm.weight', 'model.decoder.block.3.layer.1.EncDecAttention.q.weight', 'model.decoder.block.5.layer.2.DenseReluDense.wo.weight', 'model.encoder.block.3.layer.1.layer_norm.weight', 'model.decoder.block.1.layer.1.EncDecAttention.k.weight', 'model.decoder.block.0.layer.0.SelfAttention.o.weight', 'model.encoder.block.2.layer.0.SelfAttention.q.weight', 'model.decoder.block.5.layer.0.SelfAttention.o.weight', 'model.decoder.block.4.layer.0.SelfAttention.o.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.block.2.layer.2.DenseReluDense.wi.weight', 'model.encoder.block.3.layer.0.SelfAttention.o.weight', 'model.encoder.block.0.layer.1.DenseReluDense.wi.weight', 'model.encoder.block.1.layer.0.SelfAttention.o.weight', 'model.decoder.block.0.layer.1.layer_norm.weight', 'model.decoder.block.3.layer.0.SelfAttention.o.weight', 'classification_head.dense.weight', 'model.encoder.block.5.layer.0.SelfAttention.v.weight', 'model.decoder.block.5.layer.1.EncDecAttention.v.weight', 'model.decoder.block.3.layer.2.DenseReluDense.wo.weight', 'model.decoder.block.4.layer.1.EncDecAttention.o.weight', 'model.decoder.block.2.layer.1.EncDecAttention.o.weight', 'model.decoder.block.4.layer.1.EncDecAttention.q.weight', 'model.shared.weight', 'model.decoder.block.4.layer.0.layer_norm.weight', 'model.encoder.block.5.layer.0.layer_norm.weight', 'model.encoder.block.5.layer.1.DenseReluDense.wi.weight', 'model.decoder.block.2.layer.1.EncDecAttention.q.weight', 'model.decoder.block.3.layer.1.layer_norm.weight', 'model.decoder.block.5.layer.1.layer_norm.weight', 'model.encoder.block.2.layer.1.layer_norm.weight', 'model.encoder.block.0.layer.0.SelfAttention.v.weight', 'model.encoder.block.2.layer.0.SelfAttention.k.weight', 'model.decoder.block.4.layer.0.SelfAttention.q.weight', 'model.encoder.block.1.layer.0.layer_norm.weight', 'model.decoder.block.1.layer.1.EncDecAttention.v.weight', 'model.decoder.block.0.layer.2.DenseReluDense.wi.weight', 'model.encoder.block.4.layer.0.SelfAttention.q.weight', 'model.decoder.block.5.layer.2.layer_norm.weight', 'model.decoder.block.5.layer.1.EncDecAttention.k.weight', 'model.encoder.block.3.layer.1.DenseReluDense.wi.weight', 'classification_head.out_proj.bias', 'model.encoder.block.5.layer.0.SelfAttention.q.weight', 'model.decoder.block.1.layer.0.SelfAttention.q.weight', 
'model.decoder.block.1.layer.1.EncDecAttention.o.weight', 'model.decoder.block.1.layer.2.DenseReluDense.wo.weight', 'model.encoder.block.2.layer.0.layer_norm.weight', 'model.decoder.block.1.layer.0.SelfAttention.k.weight', 'model.decoder.block.5.layer.0.layer_norm.weight', 'model.encoder.block.1.layer.0.SelfAttention.k.weight', 'model.decoder.block.4.layer.2.layer_norm.weight', 'model.decoder.block.1.layer.2.layer_norm.weight', 'model.decoder.block.5.layer.0.SelfAttention.k.weight', 'model.decoder.block.2.layer.1.EncDecAttention.v.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

In the case of BartForSequenceClassification with facebook/bart-large, this is how it looks:

Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['classification_head.dense.weight', 'classification_head.dense.bias', 'classification_head.out_proj.weight', 'classification_head.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Am I missing something, or does something extra have to be done to load the model? 🤔
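
For what it's worth, one thing I notice in the two lists above is that they differ by a model. prefix: the checkpoint stores keys like encoder.block.0..., while my class reports keys like model.encoder.block.0..., which is probably why nothing matches during loading. A quick way to see the naming difference (illustrative sketch, not the actual loading code):

from transformers import T5Model

# The t5-small checkpoint keys come straight from T5Model (no wrapper prefix).
backbone = T5Model.from_pretrained("t5-small")
checkpoint_key = next(iter(backbone.state_dict().keys()))
print(checkpoint_key)               # a name without any wrapper prefix
# A class that stores the backbone as `self.model = T5Model(config)` would instead
# expect every key to carry a "model." prefix:
print("model." + checkpoint_key)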

@MetcalfeTom
Author

MetcalfeTom commented Jan 24, 2022

@stefan-it / @subhalingamd the code for BartForSequenceClassification loads both the encoder and decoder parts of BART, which doesn't follow the EncT5 paper - the model should be the T5 encoder only.

My solution is here; I'm happy to push it, though it's a lot of duplicated code. Should some refactoring be done between this and BART?

The other addition from the EncT5 paper is that the encoder outputs are pooled to simulate the text-to-text nature of classification using NLG. This is not in my implementation but can be added.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@subhalingamd
Contributor

@LysandreJik
can we expect this to be added sometime soon?

@osainz59

Hi,
I don't know if anyone is still working on this, but looking at recent works that use just the encoder from T5, I implemented the class T5ForSequenceClassification and evaluated it on GLUE.

Here is the repository with the code and results: https://github.com/osainz59/t5-encoder

Is there still interest in implementing this in the Transformers library?

@stefan-it
Collaborator

Hi @osainz59, your repo looks really interesting. Did you also perform experiments on NER? I've recently added support for encoder-only fine-tuning to the Flair library, and the results are really promising for this kind of downstream task. It would be awesome to have both sequence and token classification in Transformers 🤗

@osainz59

Hi @stefan-it, I have not tested T5ForTokenClassification yet since, afaik, there is no direct comparison to traditional T5. However, I can run it on some datasets to ensure it works properly. Can you tell me which datasets might be of interest to test?

@stefan-it
Collaborator

stefan-it commented Dec 14, 2022

Hi @osainz59, I think one really interesting dataset would be CoNLL-2003 (see https://huggingface.co/datasets/conll2003).

When testing the mT5 model series, WikiANN (Rahimi splits, from here: https://huggingface.co/datasets/wikiann) is also very interesting: train on the English split only and test on the other languages for comparison with the mT5 paper :)
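
Both are available through the datasets library, e.g. (illustrative):

from datasets import load_dataset

conll = load_dataset("conll2003")
wikiann_en = load_dataset("wikiann", "en")   # English split for the zero-shot transfer setup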

@osainz59

Hi @stefan-it, I trained and evaluated the T5ForTokenClassification class on CoNLL-2003, and here are the results:

***** eval metrics *****
  epoch                   =       25.0
  eval_accuracy           =     0.9916
  eval_f1                 =     0.9549
  eval_loss               =     0.0449
  eval_precision          =     0.9531
  eval_recall             =     0.9567
  eval_runtime            = 0:00:10.90
  eval_samples            =       3251
  eval_samples_per_second =    298.079
  eval_steps_per_second   =     37.317

I think they are still a bit behind RoBERTa, but at those levels of F1 it is hard to tell. Nevertheless, I think these results suggest that T5-Enc could be an interesting addition to the Transformers library.

@stefan-it
Collaborator

Hey @osainz59 thanks for reporting back! I would love to see this in Transformers directly!

@sunyuhan19981208

Hi, I don't know if anyone is still working on this, but looking at recent works that use just the encoder from T5, I implemented the class T5ForSequenceClassification and evaluated it on GLUE.

Here is the repository with the code and results: https://github.com/osainz59/t5-encoder

Is there still interest in implementing this in the Transformers library?

This is exactly what I am looking for!

@sjrl
Contributor

sjrl commented Jul 24, 2023

I've opened PR #24726 for T5ForSequenceClassification following the structure of BartForSequenceClassification, so both encoder and decoder weights are used.

Although, based on the results shown in this thread, it seems like we could also look into adding a version that uses only the encoder.

@hackyon
Contributor

hackyon commented Oct 9, 2023

I've opened PR #24726 for T5ForSequenceClassification following the structure of BartForSequenceClassification, so both encoder and decoder weights are used.

Although, based on the results shown in this thread, it seems like we could also look into adding a version that uses only the encoder.

FYI - I just created pull request #26683 for EncT5, which is similar to T5ForSequenceClassification but focuses on the encoder layers only (while still having a single decoder layer). It is based on the results of https://arxiv.org/pdf/2110.08426.pdf.
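
To give a rough idea of the approach (an illustrative sketch only, not the code in the PR): a single learned query cross-attends over the encoder outputs to pool them into one vector, which then feeds a classification head.

import torch
from torch import nn

# Illustrative EncT5-style pooling: one learned query attends over encoder outputs.
class SingleQueryPooler(nn.Module):
    def __init__(self, d_model, num_heads=8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, d_model))
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

    def forward(self, encoder_hidden_states, attention_mask=None):
        batch_size = encoder_hidden_states.size(0)
        query = self.query.expand(batch_size, -1, -1)
        key_padding_mask = None if attention_mask is None else attention_mask.eq(0)
        pooled, _ = self.cross_attn(query, encoder_hidden_states, encoder_hidden_states,
                                    key_padding_mask=key_padding_mask)
        return pooled.squeeze(1)                 # (batch, d_model)

pooler = SingleQueryPooler(d_model=512)          # d_model of t5-small
pooled = pooler(torch.randn(2, 7, 512))          # -> (2, 512)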

@hackyon
Contributor

hackyon commented Oct 11, 2023

Thanks @sjrl!

I wanted to circle back with more people on this thread to see if anyone else in the community is still interested in an encoder-only EncT5 variant of the T5ForSequenceClassification model, as proposed in this paper.

I have a draft out in #26683, but since the paper is ~2 years old now, @ArthurZucker suggested we see if there is still interest in the community before moving forward.

@sunyuhan19981208 @stefan-it @osainz59 @subhalingamd @prajjwal1 - let us know if you think the encoder-variant EncT5 is worth adding to the library. Thanks!

@osainz59

Hi @hackyon !

My implementation of encoder-only T5 for sequence and token classification still has some activity (clones and visits). It is not much, but people definitely use it.

@dwyatte
Contributor

dwyatte commented Oct 11, 2023

@hackyon As another data point, in my experiments I found the EncT5 approach to work well for efficiently scaling models with encoders to >1B parameters, and the single decoder layer worked much better than naively pooling the encoder outputs.

I suspect most interest in scaling has shifted to decoder-only models, but if this can be done without a ton of "model sprawl" as you mention in your PR, I think it could be a nice addition.

@hackyon
Contributor

hackyon commented Oct 13, 2023

Thanks @osainz59 and @dwyatte for your inputs! I'll follow up with @ArthurZucker again to see if we can move forward.

PS. @dwyatte - could be worth checking out this comment from Frederick. Seems like there's more research/evidence that supports the single decoder layer approach (at least for multi-label classification in this case).

@mahita2104

Can I work on this? Could you assign it to me?

@hackyon
Contributor

hackyon commented Oct 23, 2023

Hello @mahita2104. I believe this issue has more or less been resolved already by @sjrl in #24726. You should be able to find T5ForSequenceClassification in the latest codebase. I think this particular issue can probably be marked as closed.

I am working on an extension to this in #26683, but am still iterating on it.
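
In the meantime, a minimal usage sketch of the merged class (assuming a recent transformers release that includes #24726; the checkpoint and num_labels are just for illustration):

from transformers import AutoTokenizer, T5ForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForSequenceClassification.from_pretrained("t5-small", num_labels=2)

inputs = tokenizer("This movie was great!", return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.shape)   # (1, num_labels); the classification head is newly initialized and needs fine-tuning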

@ArthurZucker
Collaborator

Yes, closing this one. Thanks @hackyon 🤗
