added GPTNeoForTokenClassification #22908

Merged on Apr 27, 2023 (20 commits; changes shown are from 13 commits)
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/gpt_neo.mdx
@@ -74,6 +74,11 @@ The `generate()` method can be used to generate text using GPT Neo model.
[[autodoc]] GPTNeoForSequenceClassification
- forward

## GPTNeoForTokenClassification

[[autodoc]] GPTNeoForTokenClassification
- forward

## FlaxGPTNeoModel

[[autodoc]] FlaxGPTNeoModel
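For orientation, a minimal usage sketch of the class documented above follows. It is not part of the PR: the checkpoint is the plain `EleutherAI/gpt-neo-125m` language model, so the token-classification head is freshly initialized and its predictions are meaningless until fine-tuned, and `num_labels=5` is an arbitrary assumption.

```python
# Minimal sketch (not from this PR): token classification with the new head.
# The base checkpoint ships without a fine-tuned classification head, so the
# head weights are randomly initialized; num_labels=5 is an arbitrary choice.
import torch
from transformers import AutoTokenizer, GPTNeoForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
model = GPTNeoForTokenClassification.from_pretrained("EleutherAI/gpt-neo-125m", num_labels=5)

inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# One prediction per input token: (batch_size, sequence_length, num_labels)
print(logits.shape)
print(logits.argmax(dim=-1))
```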
2 changes: 1 addition & 1 deletion docs/source/en/tasks/token_classification.mdx
@@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [QDQBert](../model_doc/qdqbert), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)

<!--End of the generated tip-->

2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1687,6 +1687,7 @@
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoForCausalLM",
"GPTNeoForSequenceClassification",
"GPTNeoForTokenClassification",
"GPTNeoModel",
"GPTNeoPreTrainedModel",
"load_tf_weights_in_gpt_neo",
@@ -5221,6 +5222,7 @@
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoForTokenClassification,
GPTNeoModel,
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -814,6 +814,7 @@
("gpt-sw3", "GPT2ForTokenClassification"),
("gpt2", "GPT2ForTokenClassification"),
("gpt_bigcode", "GPTBigCodeForTokenClassification"),
("gpt_neo", "GPTNeoForTokenClassification"),
("ibert", "IBertForTokenClassification"),
("layoutlm", "LayoutLMForTokenClassification"),
("layoutlmv2", "LayoutLMv2ForTokenClassification"),
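With this mapping entry in place, the generic auto classes and the `token-classification` pipeline should resolve GPT Neo configurations to the new head. A hedged sketch, again using the untrained base checkpoint and a placeholder label count:

```python
# Sketch assuming the mapping above: a "gpt_neo" config now resolves to
# GPTNeoForTokenClassification through the auto class.
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

model = AutoModelForTokenClassification.from_pretrained("EleutherAI/gpt-neo-125m", num_labels=9)
print(type(model).__name__)  # expected: GPTNeoForTokenClassification

# The same mapping backs the high-level pipeline; with an untrained head the
# predicted labels are placeholders (LABEL_0 ... LABEL_8) until fine-tuning.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
ner = pipeline("token-classification", model=model, tokenizer=tokenizer)
print(ner("Hugging Face is based in New York City"))
```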
2 changes: 2 additions & 0 deletions src/transformers/models/gpt_neo/__init__.py
@@ -30,6 +30,7 @@
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoForCausalLM",
"GPTNeoForSequenceClassification",
"GPTNeoForTokenClassification",
"GPTNeoModel",
"GPTNeoPreTrainedModel",
"load_tf_weights_in_gpt_neo",
@@ -61,6 +62,7 @@
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoForTokenClassification,
GPTNeoModel,
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,
99 changes: 99 additions & 0 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -30,6 +30,7 @@
    CausalLMOutputWithCrossAttentions,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -926,3 +927,101 @@ def forward(
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    GPT Neo model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    GPT_NEO_START_DOCSTRING,
)
class GPTNeoForTokenClassification(GPTNeoPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = GPTNeoModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
    # fmt: off
    @add_code_sample_docstrings(
        checkpoint="EleutherAI/gpt-neo-125m",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_loss=0.25,
        expected_output=["Lead", "Lead", "Lead", "Position", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead", "Lead"],
    )
    # fmt: on
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
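A note on the loss computation above: logits of shape `(batch_size, sequence_length, num_labels)` and labels of shape `(batch_size, sequence_length)` are flattened before `CrossEntropyLoss`, whose default `ignore_index` of `-100` lets callers mask out padding or special tokens. A small self-contained sketch of that behaviour (illustrative shapes, not code from the PR):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length, num_labels = 2, 6, 3
logits = torch.randn(batch_size, seq_length, num_labels)
labels = torch.randint(0, num_labels, (batch_size, seq_length))
labels[:, -2:] = -100  # e.g. padding positions; skipped by the default ignore_index

# Same flattening as in GPTNeoForTokenClassification.forward:
loss = CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
print(loss)
```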
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -3273,6 +3273,13 @@ def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class GPTNeoForTokenClassification(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


class GPTNeoModel(metaclass=DummyObject):
    _backends = ["torch"]

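The dummy object keeps `from transformers import GPTNeoForTokenClassification` working in environments without PyTorch, while any attempt to use the class fails with a clear message. A sketch of the expected behaviour, assuming torch is not installed:

```python
# Only meaningful without PyTorch installed: the name resolves to the dummy class
# above, and instantiating it raises an ImportError explaining how to install torch.
from transformers import GPTNeoForTokenClassification

try:
    GPTNeoForTokenClassification()
except ImportError as err:
    print(err)
```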
20 changes: 19 additions & 1 deletion tests/models/gpt_neo/test_modeling_gpt_neo.py
@@ -35,6 +35,7 @@
        GPT2Tokenizer,
        GPTNeoForCausalLM,
        GPTNeoForSequenceClassification,
        GPTNeoForTokenClassification,
        GPTNeoModel,
    )

@@ -334,6 +335,16 @@ def create_and_check_gpt_neo_for_sequence_classification(
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def create_and_check_gpt_neo_for_token_classification(
        self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
    ):
        config.num_labels = self.num_labels
        model = GPTNeoForTokenClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
    ):
@@ -374,13 +385,16 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
        (GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoForTokenClassification)
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": GPTNeoModel,
            "text-classification": GPTNeoForSequenceClassification,
            "token-classification": GPTNeoForTokenClassification,
            "text-generation": GPTNeoForCausalLM,
            "zero-shot": GPTNeoForSequenceClassification,
        }
@@ -428,6 +442,10 @@ def test_gpt_neo_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)

    def test_gpt_neo_token_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt_neo_for_token_classification(*config_and_inputs)

    def test_gpt_neo_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)
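To exercise only the new test locally, one option (a suggestion, not part of the PR) is to filter on its name:

```python
# Equivalent to: pytest tests/models/gpt_neo/test_modeling_gpt_neo.py -k token_classification -q
import pytest

pytest.main(["tests/models/gpt_neo/test_modeling_gpt_neo.py", "-k", "token_classification", "-q"])
```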