diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index a507a51e20..063fc63f5a 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -2596,6 +2596,7 @@
         "GPT2AdapterModel",
         "GPT2ModelWithHeads",
         "GPTJAdapterModel",
+        "GPTNeoXAdapterModel",
         "HoulsbyConfig",
         "HoulsbyInvConfig",
         "IA3Config",
@@ -3456,6 +3457,10 @@
         ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
     )
     _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"])
+    _import_structure["models.gpt_neox"].extend(
+        ["FlaxGPTNeoXForCausalLM", "FlaxGPTNeoXModel", "FlaxGPTNeoXPreTrainedModel"]
+    )
+
     _import_structure["models.longt5"].extend(
         ["FlaxLongT5ForConditionalGeneration", "FlaxLongT5Model", "FlaxLongT5PreTrainedModel"]
     )
@@ -5700,6 +5705,7 @@
         GPT2AdapterModel,
         GPT2ModelWithHeads,
         GPTJAdapterModel,
+        GPTNeoXAdapterModel,
         HoulsbyConfig,
         HoulsbyInvConfig,
         IA3Config,
@@ -6399,6 +6405,7 @@
     from .models.encoder_decoder import FlaxEncoderDecoderModel
     from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
     from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
+    from .models.gpt_neox import FlaxGPTNeoXForCausalLM, FlaxGPTNeoXModel, FlaxGPTNeoXPreTrainedModel
     from .models.gptj import FlaxGPTJForCausalLM, FlaxGPTJModel, FlaxGPTJPreTrainedModel
     from .models.longt5 import FlaxLongT5ForConditionalGeneration, FlaxLongT5Model, FlaxLongT5PreTrainedModel
     from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
diff --git a/src/transformers/adapters/__init__.py b/src/transformers/adapters/__init__.py
index a07b8e8469..3842c5fe37 100644
--- a/src/transformers/adapters/__init__.py
+++ b/src/transformers/adapters/__init__.py
@@ -113,6 +113,10 @@
             "GPT2AdapterModel",
             "GPT2ModelWithHeads",
         ],
+        "models.gpt_neox": [
+            "GPTNeoXAdapterModel",
+            "GPTNeoXModelWithHeads",
+        ],
         "models.gptj": ["GPTJAdapterModel"],
         "models.mbart": [
             "MBartAdapterModel",
@@ -217,6 +221,7 @@
     from .models.debertaV2 import DebertaV2AdapterModel
     from .models.distilbert import DistilBertAdapterModel, DistilBertModelWithHeads
     from .models.gpt2 import GPT2AdapterModel, GPT2ModelWithHeads
+    from .models.gpt_neox import GPTNeoXAdapterModel, GPTNeoXModelWithHeads
     from .models.gptj import GPTJAdapterModel
     from .models.mbart import MBartAdapterModel, MBartModelWithHeads
     from .models.roberta import RobertaAdapterModel, RobertaModelWithHeads
diff --git a/src/transformers/adapters/head_utils.py b/src/transformers/adapters/head_utils.py
index 05f3a54566..cf38f3d7f8 100644
--- a/src/transformers/adapters/head_utils.py
+++ b/src/transformers/adapters/head_utils.py
@@ -381,6 +381,13 @@
         },
         "layers": [None, "classifier"],
     },
+    # GPT-NeoX
+    "GPTNeoXForCausalLM": {
+        "config": {
+            "head_type": "causal_lm",
+        },
+        "layers": ["embed_out"],
+    },
     # GPT-J
     "GPTJForSequenceClassification": {
         "config": {
diff --git a/src/transformers/adapters/mixins/gpt_neox.py b/src/transformers/adapters/mixins/gpt_neox.py
new file mode 100644
index 0000000000..ff520ee9aa
--- /dev/null
+++ b/src/transformers/adapters/mixins/gpt_neox.py
@@ -0,0 +1,32 @@
+from typing import Iterable, Tuple
+
+import torch.nn as nn
+
+from ..layer import AdapterLayer
+from ..model_mixin import (
+    EmbeddingAdaptersMixin,
+    EmbeddingAdaptersWrapperMixin,
+    InvertibleAdaptersMixin,
+    ModelAdaptersMixin,
+    ModelWithHeadsAdaptersMixin,
+)
+
+
+class GPTNeoXDecoderBlockAdaptersMixin:
+    """Adds adapters to the GPTNeoXLayer module of GPT-NeoX."""
+
+    def _init_adapter_modules(self):
+        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
+        self.output_adapters = AdapterLayer("output_adapter", self.config)
+        self.attention_adapters._init_adapter_modules()
+        self.output_adapters._init_adapter_modules()
+
+
+class GPTNeoXModelAdapterMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin):
+    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
+        for i, layer in enumerate(self.base_model.layers):
+            yield i, layer
+
+
+class GPTNeoXModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
+    pass
diff --git a/src/transformers/adapters/models/auto/adapter_model.py b/src/transformers/adapters/models/auto/adapter_model.py
index cfd159bad6..274cdf4b0f 100644
--- a/src/transformers/adapters/models/auto/adapter_model.py
+++ b/src/transformers/adapters/models/auto/adapter_model.py
@@ -20,6 +20,7 @@
         ("bart", "BartAdapterModel"),
         ("mbart", "MBartAdapterModel"),
         ("gpt2", "GPT2AdapterModel"),
+        ("gpt_neox", "GPTNeoXAdapterModel"),
         ("gptj", "GPTJAdapterModel"),
         ("t5", "T5AdapterModel"),
         ("vit", "ViTAdapterModel"),
@@ -34,6 +35,7 @@
         ("bart", "BartModelWithHeads"),
         ("mbart", "MBartModelWithHeads"),
         ("gpt2", "GPT2ModelWithHeads"),
+        ("gpt_neox", "GPTNeoXModelWithHeads"),
         ("t5", "T5ModelWithHeads"),
     ]
 )
diff --git a/src/transformers/adapters/models/gpt_neox/__init__.py b/src/transformers/adapters/models/gpt_neox/__init__.py
new file mode 100644
index 0000000000..447dc9d604
--- /dev/null
+++ b/src/transformers/adapters/models/gpt_neox/__init__.py
@@ -0,0 +1,39 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
+# Copyright 2020 The Adapter-Hub Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+
+
+_import_structure = {
+    "adapter_model": ["GPTNeoXAdapterModel"],
+}
+
+
+if TYPE_CHECKING:
+    from .adapter_model import GPTNeoXAdapterModel
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )
diff --git a/src/transformers/adapters/models/gpt_neox/adapter_model.py b/src/transformers/adapters/models/gpt_neox/adapter_model.py
new file mode 100644
index 0000000000..8573155f04
--- /dev/null
+++ b/src/transformers/adapters/models/gpt_neox/adapter_model.py
@@ -0,0 +1,132 @@
+import logging
+
+import torch
+
+from ....models.gpt_neox.modeling_gpt_neox import GPT_NEOX_START_DOCSTRING, GPTNeoXModel, GPTNeoXPreTrainedModel
+from ....utils import add_start_docstrings
+from ...composition import adjust_tensors_for_parallel
+from ...heads import CausalLMHead, ModelWithFlexibleHeadsAdaptersMixin
+from ...model_mixin import EmbeddingAdaptersWrapperMixin
+
+
+logger = logging.getLogger(__name__)
+
+
+@add_start_docstrings(
+    """
+The GPTNeoX Model that allows the loading of different heads for different tasks. This enables a flexible use of the
+models and adapters. Since this class does classification on the last token, it requires to know the position of the
+last token. If a :obj:`pad_token_id` is defined in the configuration, it finds the last token that is not a padding
+token in each row. If no :obj:`pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
+it cannot guess the padding tokens when :obj:`inputs_embeds` are passed instead of :obj:`input_ids`, it does the same
+(take the last value in each row of the batch).
+""",
+    GPT_NEOX_START_DOCSTRING,
+)
+class GPTNeoXAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPTNeoXPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.gpt_neox = GPTNeoXModel(config)
+
+        self._init_head_modules()
+
+        self.init_weights()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        head=None,
+        output_adapter_gating_scores=False,
+        output_adapter_fusion_attentions=False,
+        **kwargs
+    ):
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.gpt_neox(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            output_adapter_gating_scores=output_adapter_gating_scores,
+            output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
+        )
+
+        batch_size = outputs[0].shape[0]
+
+        if self.config.pad_token_id is None:
+            # TODO-AH: this may result in unexpected behavior for classification. Find a better way to do this?
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
+                (sequence_lengths,) = adjust_tensors_for_parallel(outputs[0], sequence_lengths)
+            else:
+                sequence_lengths = -1
+                logger.warning(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+
+        cls_logits = outputs[0][range(batch_size), sequence_lengths]
+
+        outputs = self.forward_head(
+            outputs,
+            head_name=head,
+            cls_output=cls_logits,
+            attention_mask=attention_mask,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+        return outputs
+
+    # Copied from GPTNeoXForCausalLM
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past_key_values and past_key_values[0] is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+        }
+
+    head_types = {"causal_lm": CausalLMHead}
+
+    def add_causal_lm_head(self, head_name, overwrite_ok=False):
+        """
+        Adds a causal language modeling head on top of the model.
+
+        Args:
+            head_name (str): The name of the head.
+            overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
+        """
+        head = CausalLMHead(self, head_name)
+        self.add_prediction_head(head, overwrite_ok=overwrite_ok)
diff --git a/src/transformers/adapters/wrappers/configuration.py b/src/transformers/adapters/wrappers/configuration.py
index 56da1694bd..ac68df8591 100644
--- a/src/transformers/adapters/wrappers/configuration.py
+++ b/src/transformers/adapters/wrappers/configuration.py
@@ -35,6 +35,10 @@
         "hidden_dropout_prob": "resid_pdrop",
         "attention_probs_dropout_prob": "attn_pdrop",
     },
+    "gpt_neox": {
+        "hidden_dropout_prob": "hidden_dropout_prob",
+        "attention_probs_dropout_prob": "attention_probs_dropout_prob",
+    },
     "gptj": {
         "hidden_dropout_prob": "resid_pdrop",
         "attention_probs_dropout_prob": "attn_pdrop",
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index ce3de79bd4..2786dbc23d 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -209,6 +209,7 @@
         ("funnel", "FunnelForPreTraining"),
         ("gpt-sw3", "GPT2LMHeadModel"),
         ("gpt2", "GPT2LMHeadModel"),
+        ("gpt_neox", "GPTNeoXForCausalLM"),
         ("ibert", "IBertForMaskedLM"),
         ("layoutlm", "LayoutLMForMaskedLM"),
         ("longformer", "LongformerForMaskedLM"),
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 589eaae780..f39b750aca 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -22,6 +22,16 @@
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
+from ...adapters.composition import adjust_tensors_for_parallel
+from ...adapters.context import ForwardContext
+from ...adapters.lora import Linear as LoRALinear
+from ...adapters.lora import MergedLinear as LoRAMergedLinear
+from ...adapters.mixins.gpt_neox import (
+    GPTNeoXDecoderBlockAdaptersMixin,
+    GPTNeoXModelAdapterMixin,
+    GPTNeoXModelWithHeadsAdaptersMixin,
+)
+from ...adapters.prefix_tuning import PrefixTuningShim
 from ...file_utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -83,6 +93,7 @@ def __init__(self, config):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_attention_heads
         self.rotary_ndims = int(self.head_size * config.rotary_pct)
+        self.prefix_tuning = PrefixTuningShim("self_prefix", config)
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
@@ -95,7 +106,13 @@ def __init__(self, config):
             self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
         )
         self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())
-        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
+        self.query_key_value = LoRAMergedLinear(
+            config.hidden_size,
+            3 * config.hidden_size,
+            "selfattn",
+            config,
+            fan_in_fan_out=False,
+        )
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
 
     def forward(
@@ -149,6 +166,9 @@
             value = torch.cat((past_value, value), dim=-2)
         present = (key, value) if use_cache else None
 
+        key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
+        (query,) = adjust_tensors_for_parallel(key, query)
+
         # Compute attention
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 
@@ -286,8 +306,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
 class GPTNeoXMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.dense_h_to_4h = LoRALinear(config.hidden_size, config.intermediate_size, "intermediate", config)
+        self.dense_4h_to_h = LoRALinear(config.intermediate_size, config.hidden_size, "output", config)
         self.act = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_states):
@@ -297,14 +317,16 @@ def forward(self, hidden_states):
         return hidden_states
 
 
-class GPTNeoXLayer(nn.Module):
+class GPTNeoXLayer(GPTNeoXDecoderBlockAdaptersMixin, nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.attention = GPTNeoXAttention(config)
         self.mlp = GPTNeoXMLP(config)
+        self._init_adapter_modules()
 
     def forward(
         self,
@@ -331,14 +353,23 @@ def forward(
             # pseudocode:
             # x = x + attn(ln1(x)) + mlp(ln2(x))
             mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
-            hidden_states = mlp_output + attn_output + hidden_states
+            # See https://github.com/adapter-hub/adapter-transformers/pull/426#discussion_r994450898
+            hidden_states = self.attention_adapters(attn_output, hidden_states, None)
+            hidden_states = self.output_adapters(mlp_output, hidden_states, None)
+            # hidden_states = mlp_output + attn_output + hidden_states
+
         else:
             # pseudocode:
             # x = x + attn(ln1(x))
             # x = x + mlp(ln2(x))
-            attn_output = attn_output + hidden_states
-            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
-            hidden_states = mlp_output + attn_output
+            hidden_states = self.attention_adapters(
+                attn_output, hidden_states, None
+            )  # attn_output = attn_output + hidden_states
+            residual = hidden_states
+            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
+            # residual connection
+            hidden_states = self.output_adapters(mlp_output, residual, None)
+            # hidden_states = mlp_output + attn_output
 
         if use_cache:
             outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
@@ -413,7 +444,7 @@ def forward(
     "The bare GPTNeoX Model transformer outputting raw hidden-states without any specific head on top.",
     GPT_NEOX_START_DOCSTRING,
 )
-class GPTNeoXModel(GPTNeoXPreTrainedModel):
+class GPTNeoXModel(GPTNeoXModelAdapterMixin, GPTNeoXPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.config = config
@@ -424,6 +455,8 @@ def __init__(self, config):
 
         self.gradient_checkpointing = False
 
+        self._init_adapter_modules()
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -440,6 +473,7 @@ def set_input_embeddings(self, value):
         output_type=BaseModelOutputWithPast,
         config_class=_CONFIG_FOR_DOC,
     )
+    @ForwardContext.wrap
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -511,7 +545,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_in(input_ids)
-
+        inputs_embeds = self.invertible_adapters_forward(inputs_embeds)
         hidden_states = inputs_embeds
 
         presents = () if use_cache else None
 
@@ -552,6 +586,8 @@ def custom_forward(*inputs):
                     output_attentions=output_attentions,
                 )
             hidden_states = outputs[0]
+            (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask)
+
             if use_cache is True:
                 presents = presents + (outputs[1],)
             if output_attentions:
@@ -576,7 +612,7 @@ def custom_forward(*inputs):
 
 
 @add_start_docstrings(
     """GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
 )
-class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
+class GPTNeoXForCausalLM(GPTNeoXModelWithHeadsAdaptersMixin, GPTNeoXPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index 7e951fdb19..29780928e6 100755
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -117,6 +117,7 @@ def _generate_supported_model_class_names(
     "electra",
     "gpt2",
     "gpt_neo",
+    "gpt_neox",
     "gptj",
     "hubert",
     "layoutlm",
diff --git a/tests_adapters/test_gpt_neox.py b/tests_adapters/test_gpt_neox.py
new file mode 100644
index 0000000000..24dfab7107
--- /dev/null
+++ b/tests_adapters/test_gpt_neox.py
@@ -0,0 +1,62 @@
+import unittest
+
+from transformers import GPTNeoXConfig
+from transformers.testing_utils import require_torch
+
+from .methods import (
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    PrefixTuningTestMixin,
+    UniPELTTestMixin,
+)
+from .test_adapter import AdapterTestBase, make_config
+from .test_adapter_backward_compability import CompabilityTestMixin
+from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin
+from .test_adapter_conversion import ModelClassConversionTestMixin
+from .test_adapter_embeddings import EmbeddingTestMixin
+from .test_adapter_fusion_common import AdapterFusionModelTestMixin
+from .test_adapter_heads import PredictionHeadModelTestMixin
+
+
+class GPTNeoXAdapterTestBase(AdapterTestBase):
+    config_class = GPTNeoXConfig
+    config = make_config(
+        GPTNeoXConfig,
+        hidden_size=32,
+        num_hidden_layers=4,
+        num_attention_heads=4,
+        # set pad token to eos token
+        pad_token_id=50256,
+    )
+    tokenizer_name = "EleutherAI/gpt-neox-20b"
+
+
+@require_torch
+class GPTNeoXAdapterTest(
+    BottleneckAdapterTestMixin,
+    CompacterTestMixin,
+    IA3TestMixin,
+    LoRATestMixin,
+    UniPELTTestMixin,
+    PrefixTuningTestMixin,
+    EmbeddingTestMixin,
+    CompabilityTestMixin,
+    AdapterFusionModelTestMixin,
+    PredictionHeadModelTestMixin,
+    ParallelAdapterInferenceTestMixin,
+    ParallelTrainingMixin,
+    GPTNeoXAdapterTestBase,
+    unittest.TestCase,
+):
+    pass
+
+
+@require_torch
+class GPTNeoXClassConversionTest(
+    ModelClassConversionTestMixin,
+    GPTNeoXAdapterTestBase,
+    unittest.TestCase,
+):
+    pass
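
The diff above wires GPT-NeoX into the adapter framework end to end (flexible-head AdapterModel, adapter mixins, LoRA and prefix-tuning hooks, auto mappings, tests). The snippet below is a minimal usage sketch, not part of the diff: it assumes the existing adapter-transformers APIs (add_adapter, train_adapter) together with the add_causal_lm_head method introduced in this PR, and the checkpoint name is only illustrative (taken from the tests; any GPT-NeoX checkpoint should behave the same).

# Usage sketch for the new GPTNeoXAdapterModel (illustrative, not part of the diff).
from transformers import AutoTokenizer
from transformers.adapters import GPTNeoXAdapterModel

model = GPTNeoXAdapterModel.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Add a bottleneck adapter plus the causal LM head introduced in this PR,
# then activate the adapter (and freeze the base model) for training.
model.add_adapter("demo_adapter", config="pfeiffer")
model.add_causal_lm_head("demo_adapter")
model.train_adapter("demo_adapter")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)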