Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add GPTNeoXForSequenceClassification #22671

Merged
merged 3 commits into from
Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/gpt_neox.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,8 @@ The `generate()` method can be used to generate text using GPT Neo model.

[[autodoc]] GPTNeoXForCausalLM
- forward

## GPTNeoXForSequenceClassification

[[autodoc]] GPTNeoXForSequenceClassification
- forward
2 changes: 1 addition & 1 deletion docs/source/en/tasks/sequence_classification.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)


<!--End of the generated tip-->
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,6 +1666,7 @@
[
"GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoXForCausalLM",
"GPTNeoXForSequenceClassification",
"GPTNeoXLayer",
"GPTNeoXModel",
"GPTNeoXPreTrainedModel",
Expand Down Expand Up @@ -5164,6 +5165,7 @@
from .models.gpt_neox import (
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXLayer,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@
("gpt2", "GPT2ForSequenceClassification"),
("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
("gpt_neo", "GPTNeoForSequenceClassification"),
("gpt_neox", "GPTNeoXForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/gpt_neox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
_import_structure["modeling_gpt_neox"] = [
"GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoXForCausalLM",
"GPTNeoXForSequenceClassification",
"GPTNeoXLayer",
"GPTNeoXModel",
"GPTNeoXPreTrainedModel",
Expand All @@ -62,6 +63,7 @@
from .modeling_gpt_neox import (
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoXForCausalLM,
GPTNeoXForSequenceClassification,
GPTNeoXLayer,
GPTNeoXModel,
GPTNeoXPreTrainedModel,
Expand Down
132 changes: 130 additions & 2 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import (
Expand All @@ -28,7 +28,7 @@
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_gpt_neox import GPTNeoXConfig
Expand Down Expand Up @@ -730,3 +730,131 @@ def _reorder_cache(self, past_key_values, beam_idx):
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past


@add_start_docstrings(
"""
The GPTNeoX Model transformer with a sequence classification head on top (linear layer).

[`GPTNeoXForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-1) do.

Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
GPT_NEOX_START_DOCSTRING,
)
class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.gpt_neox = GPTNeoXModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.gpt_neox(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.score(hidden_states)

if input_ids is not None:
batch_size, sequence_length = input_ids.shape[:2]
else:
batch_size, sequence_length = inputs_embeds.shape[:2]

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
sgugger marked this conversation as resolved.
Show resolved Hide resolved
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
print(pooled_logits.shape, labels.shape)
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3236,6 +3236,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class GPTNeoXForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class GPTNeoXLayer(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
30 changes: 27 additions & 3 deletions tests/models/gpt_neox/test_modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
if is_torch_available():
import torch

from transformers import GPTNeoXForCausalLM, GPTNeoXModel
from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXModel


class GPTNeoXModelTester:
Expand Down Expand Up @@ -80,6 +80,7 @@ def __init__(
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.pad_token_id = vocab_size - 1

def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
Expand Down Expand Up @@ -110,6 +111,7 @@ def get_config(self):
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)

def prepare_config_and_inputs_for_decoder(self):
Expand Down Expand Up @@ -142,6 +144,15 @@ def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_la
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_for_sequence_classification(self, config, input_ids, input_mask, token_labels):
config.num_labels = self.num_labels
model = GPTNeoXForSequenceClassification(config)
model.to(torch_device)
model.eval()
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
config.is_decoder = True
model = GPTNeoXForCausalLM(config=config)
Expand Down Expand Up @@ -188,10 +199,19 @@ def prepare_config_and_inputs_for_common(self):

@require_torch
class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
all_model_classes = (
(GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": GPTNeoXModel, "text-generation": GPTNeoXForCausalLM} if is_torch_available() else {}
{
"feature-extraction": GPTNeoXModel,
"text-classification": GPTNeoXForSequenceClassification,
"text-generation": GPTNeoXForCausalLM,
"zero-shot": GPTNeoXForSequenceClassification,
}
if is_torch_available()
else {}
)
test_pruning = False
test_missing_keys = False
Expand Down Expand Up @@ -229,6 +249,10 @@ def test_model_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

def test_model_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

@unittest.skip(reason="Feed forward chunking is not implemented")
def test_feed_forward_chunking(self):
pass
Expand Down