Adding adapter support for NeoX #523

Open · wants to merge 13 commits into base: legacy
7 changes: 7 additions & 0 deletions src/transformers/__init__.py
@@ -2596,6 +2596,7 @@
"GPT2AdapterModel",
"GPT2ModelWithHeads",
"GPTJAdapterModel",
"GPTNeoXAdapterModel",
"HoulsbyConfig",
"HoulsbyInvConfig",
"IA3Config",
@@ -3456,6 +3457,10 @@
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
)
_import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
_import_structure["models.gpt_neox"].extend(
["FlaxGPTNeoXForCausalLM", "FlaxGPTNeoXModel", "FlaxGPTNeoXPreTrainedModel"]
)

_import_structure["models.longt5"].extend(
["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"]
)
@@ -5700,6 +5705,7 @@
GPT2AdapterModel,
GPT2ModelWithHeads,
GPTJAdapterModel,
GPTNeoXAdapterModel,
HoulsbyConfig,
HoulsbyInvConfig,
IA3Config,
@@ -6399,6 +6405,7 @@
from .models.encoder_decoder import FlaxEncoderDecoderModel
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.gpt_neox import FlaxGPTNeoXForCausalLM, FlaxGPTNeoXModel, FlaxGPTNeoXPreTrainedModel
from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
from .models.longt5 import FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, FlaxLongT5PreTrainedModel
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
5 changes: 5 additions & 0 deletions src/transformers/adapters/__init__.py
@@ -113,6 +113,10 @@
"GPT2AdapterModel",
"GPT2ModelWithHeads",
],
"models.gpt_neox": [
"GPTNeoXAdapterModel",
"GPTNeoXModelWithHeads",
],
"models.gptj": ["GPTJAdapterModel"],
"models.mbart": [
"MBartAdapterModel",
@@ -217,6 +221,7 @@
from .models.debertaV2 import DebertaV2AdapterModel
from .models.distilbert import DistilBertAdapterModel, DistilBertModelWithHeads
from .models.gpt2 import GPT2AdapterModel, GPT2ModelWithHeads
from .models.gpt_neox import GPTNeoXAdapterModel, GPTNeoXModelWithHeads
from .models.gptj import GPTJAdapterModel
from .models.mbart import MBartAdapterModel, MBartModelWithHeads
from .models.roberta import RobertaAdapterModel, RobertaModelWithHeads
7 changes: 7 additions & 0 deletions src/transformers/adapters/head_utils.py
@@ -381,6 +381,13 @@
},
"layers": [None, "classifier"],
},
# GPT-NeoX
"GPTNeoXForCausalLM": {
"config": {
"head_type": "causal_lm",
},
"layers": ["embed_out"],
},
# GPT-J
"GPTJForSequenceClassification": {
"config": {
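For context (not part of the diff): this mapping is what lets a checkpoint saved with the static `GPTNeoXForCausalLM` class be loaded into the new flex-head model, with `embed_out` converted into a `causal_lm` prediction head. A hedged sketch of that flow, assuming the usual static-to-flex head conversion applies and using an example checkpoint name:

```python
# Sketch only: the checkpoint name is an example, and the automatic head conversion is
# assumed to behave as it does for the other AdapterModel classes.
from transformers.adapters import GPTNeoXAdapterModel

model = GPTNeoXAdapterModel.from_pretrained("EleutherAI/pythia-70m")
print(model.heads.keys())  # expected: one automatically converted causal LM head
```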
32 changes: 32 additions & 0 deletions src/transformers/adapters/mixins/gpt_neox.py
@@ -0,0 +1,32 @@
from typing import Iterable, Tuple

import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)


class GPTNeoXDecoderBlockAdaptersMixin:
"""Adds adapters to the TransformerBlock module of DistilBert."""

def _init_adapter_modules(self):
self.attention_adapters = AdapterLayer("mh_adapter", self.config)
self.output_adapters = AdapterLayer("output_adapter", self.config)
self.attention_adapters._init_adapter_modules()
self.output_adapters._init_adapter_modules()


class GPTNeoXModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin):
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.base_model.layers):
yield i, layer


class GPTNeoXModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
pass
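For context, a short sketch of how these mixins surface once wired into the modeling code, assuming (as with the other backbones) that `GPTNeoXModelAdapterMixin` is mixed into `GPTNeoXModel` and the block mixin into `GPTNeoXLayer`; none of this code is in the diff:

```python
# Illustrative only: inspect the per-block adapter slots the mixins are meant to provide.
from transformers import GPTNeoXConfig
from transformers.adapters import GPTNeoXAdapterModel

model = GPTNeoXAdapterModel(
    GPTNeoXConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128)
)

# iter_layers() exposes the decoder blocks by index; each block should carry the two
# AdapterLayer slots created in GPTNeoXDecoderBlockAdaptersMixin._init_adapter_modules.
for i, block in model.base_model.iter_layers():
    print(i, hasattr(block, "attention_adapters"), hasattr(block, "output_adapters"))
```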
2 changes: 2 additions & 0 deletions src/transformers/adapters/models/auto/adapter_model.py
@@ -20,6 +20,7 @@
("bart", "BartAdapterModel"),
("mbart", "MBartAdapterModel"),
("gpt2", "GPT2AdapterModel"),
("gpt_neox", "GPTNeoXAdapterModel"),
("gptj", "GPTJAdapterModel"),
("t5", "T5AdapterModel"),
("vit", "ViTAdapterModel"),
@@ -34,6 +35,7 @@
("bart", "BartModelWithHeads"),
("mbart", "MBartModelWithHeads"),
("gpt2", "GPT2ModelWithHeads"),
("gpt_neox", "GPTNeoXModelWithHeads"),
("t5", "T5ModelWithHeads"),
]
)
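A small sanity-check sketch of what this registration provides (not in the diff; the checkpoint name is only an example): the auto class should now resolve `gpt_neox` checkpoints to the new adapter model.

```python
# Assumes this branch is installed; "EleutherAI/pythia-70m" is just an example gpt_neox checkpoint.
from transformers.adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("EleutherAI/pythia-70m")
print(type(model).__name__)  # expected: "GPTNeoXAdapterModel"
```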
39 changes: 39 additions & 0 deletions src/transformers/adapters/models/gpt_neox/__init__.py
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ....utils import _LazyModule


_import_structure = {
"adapter_model": ["GPTNeoXAdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import GPTNeoXAdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
132 changes: 132 additions & 0 deletions src/transformers/adapters/models/gpt_neox/adapter_model.py
@@ -0,0 +1,132 @@
import logging

import torch

from ....models.gpt_neox.modeling_gpt_neox import GPT_NEOX_START_DOCSTRING, GPTNeoXModel, GPTNeoXPreTrainedModel
from ....utils import add_start_docstrings
from ...composition import adjust_tensors_for_parallel
from ...heads import CausalLMHead, ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin


logger = logging.getLogger(__name__)


@add_start_docstrings(
"""
    The GPT-NeoX Model that allows loading different heads for different tasks. This enables flexible use of the
    model with adapters. Since this class does classification on the last token, it needs to know the position of
    the last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a
    padding token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of
    the batch. The same fallback is used when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, since
    padding tokens cannot be detected in that case.
""",
GPT_NEOX_START_DOCSTRING,
)
class GPTNeoXAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPTNeoXPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.gpt_neox = GPTNeoXModel(config)

self._init_head_modules()

self.init_weights()

# Model parallel
self.model_parallel = False
self.device_map = None

def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.gpt_neox(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
)

batch_size = outputs[0].shape[0]

if self.config.pad_token_id is None:
# TODO-AH: this may result in unexpected behavior for classification. Find a better way to do this?
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
(sequence_lengths,) = adjust_tensors_for_parallel(outputs[0], sequence_lengths)
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

        # Gather the hidden state at the last non-padding position of each sequence; sequence-level
        # classification heads use this as the pooled representation.
        cls_logits = outputs[0][range(batch_size), sequence_lengths]

outputs = self.forward_head(
outputs,
head_name=head,
cls_output=cls_logits,
attention_mask=attention_mask,
return_dict=return_dict,
**kwargs,
)

return outputs

# Copied from GPTNeoXForCausalLM
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
}

head_types = {"causal_lm": CausalLMHead}

def add_causal_lm_head(self, head_name, overwrite_ok=False):
"""
Adds a causal language modeling head on top of the model.

Args:
head_name (str): The name of the head.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = CausalLMHead(self, head_name)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)
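To round off the new model class, a minimal end-to-end sketch of the intended usage (not part of the diff; adapter and head names are placeholders, and the tiny config only keeps the example self-contained):

```python
import torch

from transformers import GPTNeoXConfig
from transformers.adapters import GPTNeoXAdapterModel

# Tiny randomly initialized model so the sketch runs without downloading weights.
config = GPTNeoXConfig(
    vocab_size=1024, hidden_size=64, num_hidden_layers=2, num_attention_heads=4, intermediate_size=128
)
model = GPTNeoXAdapterModel(config)

model.add_adapter("demo")         # bottleneck adapter with the default config
model.add_causal_lm_head("demo")  # the flexible head defined in this file
model.train_adapter("demo")       # freeze the backbone, train only adapter + head

input_ids = torch.randint(0, config.vocab_size, (1, 8))
outputs = model(input_ids, head="demo")
print(outputs.logits.shape)       # expected: (1, 8, vocab_size)
```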
4 changes: 4 additions & 0 deletions src/transformers/adapters/wrappers/configuration.py
@@ -35,6 +35,10 @@
"hidden_dropout_prob": "resid_pdrop",
"attention_probs_dropout_prob": "attn_pdrop",
},
"gpt_neox": {
"hidden_dropout_prob": "hidden_dropout_prob",
"attention_probs_dropout_prob": "attention_probs_dropout_prob",
},
"gptj": {
"hidden_dropout_prob": "resid_pdrop",
"attention_probs_dropout_prob": "attn_pdrop",
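For reference, the entry above follows the same pattern as the other backbones: shared adapter code reads dropout values under unified names, and the per-model map points those names at whatever the backbone config actually calls them. A rough, hypothetical sketch of that translation (the dict and helper names here are invented for illustration, not the library's internals):

```python
# Hypothetical illustration of the attribute translation such a key map drives.
CONFIG_KEY_MAP = {
    "gpt2": {"hidden_dropout_prob": "resid_pdrop", "attention_probs_dropout_prob": "attn_pdrop"},
    "gpt_neox": {
        "hidden_dropout_prob": "hidden_dropout_prob",
        "attention_probs_dropout_prob": "attention_probs_dropout_prob",
    },
}


def resolve_config_attr(config, unified_name: str):
    """Read a unified attribute name from a backbone config via the per-model key map."""
    key = CONFIG_KEY_MAP.get(config.model_type, {}).get(unified_name, unified_name)
    return getattr(config, key)
```

Note that the `gpt_neox` entry maps each unified name to itself, so it only takes effect if `GPTNeoXConfig` exposes attributes under exactly those names.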
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -209,6 +209,7 @@
("funnel", "FunnelForPreTraining"),
("gpt-sw3", "GPT2LMHeadModel"),
("gpt2", "GPT2LMHeadModel"),
("gpt_neox", "GPTNeoXForCausalLM"),
("ibert", "IBertForMaskedLM"),
("layoutlm", "LayoutLMForMaskedLM"),
("longformer", "LongformerForMaskedLM"),