Added Sequence Classification class in GPTNeo #11906

Merged · 2 commits · May 28, 2021
1 change: 1 addition & 0 deletions datasets
Submodule datasets added at d95b95
6 changes: 6 additions & 0 deletions docs/source/model_doc/gpt_neo.rst
@@ -65,3 +65,9 @@ GPTNeoForCausalLM

.. autoclass:: transformers.GPTNeoForCausalLM
:members: forward

GPTNeoForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTNeoForSequenceClassification
:members: forward
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -746,6 +746,7 @@
[
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoForCausalLM",
"GPTNeoForSequenceClassification",
"GPTNeoModel",
"GPTNeoPreTrainedModel",
"load_tf_weights_in_gpt_neo",
@@ -2129,6 +2130,7 @@
from .models.gpt_neo import (
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoModel,
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,
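
Both hunks above are needed: the first adds the class name to the lazy _import_structure map, the second to the eager imports used for type checking. A quick sanity check of the public export (a sketch, assuming a source install of this branch):

import transformers

# The top-level export should resolve to the implementation in modeling_gpt_neo
print(transformers.GPTNeoForSequenceClassification.__module__)
# transformers.models.gpt_neo.modeling_gpt_neo
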
3 changes: 2 additions & 1 deletion src/transformers/models/auto/modeling_auto.py
@@ -145,7 +145,7 @@
FunnelModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoModel
from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoForSequenceClassification, GPTNeoModel
from ..ibert.modeling_ibert import (
IBertForMaskedLM,
IBertForMultipleChoice,
@@ -632,6 +632,7 @@
(DebertaConfig, DebertaForSequenceClassification),
(DebertaV2Config, DebertaV2ForSequenceClassification),
(GPT2Config, GPT2ForSequenceClassification),
(GPTNeoConfig, GPTNeoForSequenceClassification),
(OpenAIGPTConfig, OpenAIGPTForSequenceClassification),
(ReformerConfig, ReformerForSequenceClassification),
(CTRLConfig, CTRLForSequenceClassification),
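
Registering the (GPTNeoConfig, GPTNeoForSequenceClassification) pair in the sequence-classification mapping also wires the new head into the auto API. A minimal sketch of what this enables (assuming the public EleutherAI/gpt-neo-125M checkpoint; the classification head is freshly initialized, so it needs fine-tuning before its predictions mean anything):

from transformers import AutoModelForSequenceClassification

# GPTNeoConfig now resolves to GPTNeoForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
    "EleutherAI/gpt-neo-125M", num_labels=2
)
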
2 changes: 2 additions & 0 deletions src/transformers/models/gpt_neo/__init__.py
@@ -28,6 +28,7 @@
_import_structure["modeling_gpt_neo"] = [
"GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTNeoForCausalLM",
"GPTNeoForSequenceClassification",
"GPTNeoModel",
"GPTNeoPreTrainedModel",
"load_tf_weights_in_gpt_neo",
@@ -41,6 +42,7 @@
from .modeling_gpt_neo import (
GPT_NEO_PRETRAINED_MODEL_ARCHIVE_LIST,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoModel,
GPTNeoPreTrainedModel,
load_tf_weights_in_gpt_neo,
120 changes: 119 additions & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -22,7 +22,7 @@
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -31,6 +31,7 @@
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
@@ -1027,3 +1028,120 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)


@add_start_docstrings(
"""
The GPTNeo Model transformer with a sequence classification head on top (linear layer).

:class:`~transformers.GPTNeoForSequenceClassification` uses the last token in order to do the classification, as
other causal models (e.g. GPT-1) do.

Since it does classification on the last token, it requires to know the position of the last token. If a
:obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same (take
the last value in each row of the batch).
""",
GPT_NEO_START_DOCSTRING,
)
class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = GPTNeoModel(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

self.init_weights()

@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=SequenceClassifierOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)

if input_ids is not None:
batch_size, sequence_length = input_ids.shape[:2]
else:
batch_size, sequence_length = inputs_embeds.shape[:2]

assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits[range(batch_size), sequence_lengths]

loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
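
As the docstring above explains, the head pools the logits at each row's last non-padding token, so config.pad_token_id must be set for batches larger than one. A usage sketch (assuming the EleutherAI/gpt-neo-125M checkpoint; GPT-Neo ships without a padding token, so the EOS token is reused here):

import torch
from transformers import GPT2Tokenizer, GPTNeoForSequenceClassification

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token  # GPT-Neo defines no pad token of its own
model = GPTNeoForSequenceClassification.from_pretrained(
    "EleutherAI/gpt-neo-125M", num_labels=2
)
model.config.pad_token_id = tokenizer.pad_token_id  # required for batch_size > 1

# With right padding, the head selects the logits of the last real token per row
inputs = tokenizer(["short", "a longer example"], padding=True, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (batch_size, num_labels)

Setting num_labels=1 instead switches the loss from cross-entropy to MSE, turning the same head into a regression model, as the labels docstring describes.
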
9 changes: 9 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -1603,6 +1603,15 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class GPTNeoForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["torch"])


class GPTNeoModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
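
The dummy class keeps `from transformers import GPTNeoForSequenceClassification` importable in a torch-free environment while failing loudly on use. A sketch of the expected behavior (assuming PyTorch is not installed):

from transformers import GPTNeoForSequenceClassification

# The import succeeds, but instantiation raises via requires_backends,
# telling the user that this class needs the PyTorch backend installed.
model = GPTNeoForSequenceClassification()
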
1 change: 0 additions & 1 deletion tests/test_modeling_gpt2.py
@@ -361,7 +361,6 @@ def create_and_check_gpt2_for_sequence_classification(
model = GPT2ForSequenceClassification(config)
model.to(torch_device)
model.eval()
print(config.num_labels, sequence_labels.size())
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

19 changes: 18 additions & 1 deletion tests/test_modeling_gpt_neo.py
@@ -34,6 +34,7 @@
GPT2Tokenizer,
GPTNeoConfig,
GPTNeoForCausalLM,
GPTNeoForSequenceClassification,
GPTNeoModel,
)
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoAttentionMixin
@@ -238,6 +239,16 @@ def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mas
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_gpt_neo_for_sequence_classification(
self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, *args
):
Comment on lines +242 to +244

Contributor: We need to add a test for this in the GPTNeoModelTest class, otherwise it won't run. See the GPT2 test:

    def test_gpt2_sequence_classification_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_gpt2_for_sequence_classification(*config_and_inputs)

Contributor Author: Thanks for the suggestion! I will fix it.

config.num_labels = self.num_labels
model = GPTNeoForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
model = GPTNeoForCausalLM(config)
model.to(torch_device)
@@ -274,7 +285,9 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class GPTNeoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):

all_model_classes = (GPTNeoModel, GPTNeoForCausalLM) if is_torch_available() else ()
all_model_classes = (
(GPTNeoModel, GPTNeoForCausalLM, GPTNeoForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (GPTNeoForCausalLM,) if is_torch_available() else ()
fx_ready_model_classes = all_model_classes
test_missing_keys = False
@@ -305,6 +318,10 @@ def test_gpt_neo_lm_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

def test_gpt_neo_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs)
Comment on lines +321 to +323

Contributor: nice!


def test_gpt_neo_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
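
With the new case registered in GPTNeoModelTest, it runs as part of the regular suite; it can also be selected on its own with, for example, `python -m pytest tests/test_modeling_gpt_neo.py -k sequence_classification` (assuming a dev install with pytest available).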