diff --git a/docs/source/en/model_doc/gpt_neox.mdx b/docs/source/en/model_doc/gpt_neox.mdx
index 1f5ec5e794b8db..90f6fe5d8cb274 100644
--- a/docs/source/en/model_doc/gpt_neox.mdx
+++ b/docs/source/en/model_doc/gpt_neox.mdx
@@ -78,3 +78,8 @@ The `generate()` method can be used to generate text using GPT Neo model.
 
 [[autodoc]] GPTNeoXForCausalLM
     - forward
+
+## GPTNeoXForSequenceClassification
+
+[[autodoc]] GPTNeoXForSequenceClassification
+    - forward
\ No newline at end of file
diff --git a/docs/source/en/tasks/sequence_classification.mdx b/docs/source/en/tasks/sequence_classification.mdx
index 3126ce87e22152..6f8e0676bad273 100644
--- a/docs/source/en/tasks/sequence_classification.mdx
+++ b/docs/source/en/tasks/sequence_classification.mdx
@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit
 
 <!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->
 
-[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
+[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
 
 <!--End of the generated tip-->
 
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 2776b1988da4ed..f2733cad08231e 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1666,6 +1666,7 @@
         [
             "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
             "GPTNeoXForCausalLM",
+            "GPTNeoXForSequenceClassification",
             "GPTNeoXLayer",
             "GPTNeoXModel",
             "GPTNeoXPreTrainedModel",
@@ -5164,6 +5165,7 @@
         from .models.gpt_neox import (
             GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
             GPTNeoXForCausalLM,
+            GPTNeoXForSequenceClassification,
             GPTNeoXLayer,
             GPTNeoXModel,
             GPTNeoXPreTrainedModel,
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 67a2206d46c03c..26fc67a4d313ee 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -659,6 +659,7 @@
         ("gpt2", "GPT2ForSequenceClassification"),
         ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
         ("gpt_neo", "GPTNeoForSequenceClassification"),
+        ("gpt_neox", "GPTNeoXForSequenceClassification"),
         ("gptj", "GPTJForSequenceClassification"),
         ("ibert", "IBertForSequenceClassification"),
         ("layoutlm", "LayoutLMForSequenceClassification"),
diff --git a/src/transformers/models/gpt_neox/__init__.py b/src/transformers/models/gpt_neox/__init__.py
index db5c17996a007a..197036158a5e42 100644
--- a/src/transformers/models/gpt_neox/__init__.py
+++ b/src/transformers/models/gpt_neox/__init__.py
@@ -36,6 +36,7 @@
     _import_structure["modeling_gpt_neox"] = [
         "GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST",
         "GPTNeoXForCausalLM",
+        "GPTNeoXForSequenceClassification",
         "GPTNeoXLayer",
         "GPTNeoXModel",
         "GPTNeoXPreTrainedModel",
@@ -62,6 +63,7 @@
     from .modeling_gpt_neox import (
         GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST,
         GPTNeoXForCausalLM,
+        GPTNeoXForSequenceClassification,
        GPTNeoXLayer,
         GPTNeoXModel,
         GPTNeoXPreTrainedModel,
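With the sequence-classification auto-mapping entry above registered, the new head becomes reachable through `AutoModelForSequenceClassification`. A minimal sketch of what that enables; the checkpoint name below is just one example of a GPT-NeoX-architecture model on the Hub, and the classification head is freshly initialized, so it must be fine-tuned before its predictions are meaningful:

```python
# Sketch: the auto-mapping above lets GPT-NeoX checkpoints load with the new head.
# "EleutherAI/pythia-70m" is one example of a GPT-NeoX-architecture checkpoint;
# expect a missing-weights warning, since the score head is randomly initialized.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-70m", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

inputs = tokenizer("GPT-NeoX now supports sequence classification.", return_tensors="pt")
logits = model(**inputs).logits  # shape: (1, 2)
```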
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 5c2271ebb9f77b..fdab541263bbe4 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -19,7 +19,7 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -28,7 +28,7 @@
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_gpt_neox import GPTNeoXConfig
@@ -730,3 +730,130 @@ def _reorder_cache(self, past_key_values, beam_idx):
                 tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
             )
         return reordered_past
+
+
+@add_start_docstrings(
+    """
+    The GPTNeoX Model transformer with a sequence classification head on top (linear layer).
+
+    [`GPTNeoXForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-1) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    GPT_NEOX_START_DOCSTRING,
+)
+class GPTNeoXForSequenceClassification(GPTNeoXPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.gpt_neox = GPTNeoXModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.gpt_neox(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size, sequence_length = input_ids.shape[:2]
+        else:
+            batch_size, sequence_length = inputs_embeds.shape[:2]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds`."
+                )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index 0f7886fd844e18..9524cf58621a1b 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -3236,6 +3236,13 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])
 
 
+class GPTNeoXForSequenceClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class GPTNeoXLayer(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py
index 519b10a040ff30..707e224e935641 100644
--- a/tests/models/gpt_neox/test_modeling_gpt_neox.py
+++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py
@@ -29,7 +29,7 @@
 if is_torch_available():
     import torch
 
-    from transformers import GPTNeoXForCausalLM, GPTNeoXModel
+    from transformers import GPTNeoXForCausalLM, GPTNeoXForSequenceClassification, GPTNeoXModel
 
 
 class GPTNeoXModelTester:
@@ -80,6 +80,7 @@ def __init__(
         self.num_labels = num_labels
         self.num_choices = num_choices
         self.scope = scope
+        self.pad_token_id = vocab_size - 1
 
     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -110,6 +111,7 @@ def get_config(self):
             type_vocab_size=self.type_vocab_size,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            pad_token_id=self.pad_token_id,
         )
 
     def prepare_config_and_inputs_for_decoder(self):
@@ -142,6 +144,15 @@ def create_and_check_for_causal_lm(self, config, input_ids, input_mask, token_la
         result = model(input_ids, attention_mask=input_mask, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
 
+    def create_and_check_for_sequence_classification(self, config, input_ids, input_mask, token_labels):
+        config.num_labels = self.num_labels
+        model = GPTNeoXForSequenceClassification(config)
+        model.to(torch_device)
+        model.eval()
+        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+        result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
     def create_and_check_decoder_model_past_large_inputs(self, config, input_ids, input_mask):
         config.is_decoder = True
         model = GPTNeoXForCausalLM(config=config)
@@ -188,10 +199,19 @@ def prepare_config_and_inputs_for_common(self):
 
 @require_torch
 class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
-    all_model_classes = (GPTNeoXModel, GPTNeoXForCausalLM) if is_torch_available() else ()
+    all_model_classes = (
+        (GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXForSequenceClassification) if is_torch_available() else ()
+    )
     all_generative_model_classes = (GPTNeoXForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
-        {"feature-extraction": GPTNeoXModel, "text-generation": GPTNeoXForCausalLM} if is_torch_available() else {}
+        {
+            "feature-extraction": GPTNeoXModel,
+            "text-classification": GPTNeoXForSequenceClassification,
+            "text-generation": GPTNeoXForCausalLM,
+            "zero-shot": GPTNeoXForSequenceClassification,
+        }
+        if is_torch_available()
+        else {}
     )
     test_pruning = False
     test_missing_keys = False
@@ -229,6 +249,10 @@ def test_model_for_causal_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
 
+    def test_model_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
     @unittest.skip(reason="Feed forward chunking is not implemented")
     def test_feed_forward_chunking(self):
         pass
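For completeness, an end-to-end sketch of the new class outside the test harness, mirroring what `create_and_check_for_sequence_classification` above exercises. The tiny config values are made up for illustration and the weights are random, so the outputs carry no meaning:

```python
import torch
from transformers import GPTNeoXConfig, GPTNeoXForSequenceClassification

# Made-up miniature config, purely for illustration; real checkpoints are far larger.
config = GPTNeoXConfig(
    vocab_size=64,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    num_labels=3,
    pad_token_id=63,  # required so batches > 1 can locate each row's last real token
)
model = GPTNeoXForSequenceClassification(config).eval()

input_ids = torch.randint(0, 63, (2, 10))  # batch of 2; no pad tokens in this toy batch
labels = torch.tensor([0, 2])
with torch.no_grad():
    out = model(input_ids, labels=labels)

print(out.logits.shape)  # torch.Size([2, 3]) -- one pooled logit row per example
print(out.loss)          # cross-entropy branch: integer labels with num_labels > 1
```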