Upgrade Transformers to v4.38.x (#654)
Changes:
- HF changed parts of the Llama model implementation
- HF added a `LlamaForQuestionAnswering` class. However, this model has an
incorrect base model name. I added a workaround that resolves this until it
is fixed in Transformers (huggingface/transformers#29258)
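
A rough sketch of the inconsistency and the workaround (attribute and prefix names as described in this commit; treat anything beyond the diff below as an assumption):

```python
# In Transformers v4.38.x, LlamaForQuestionAnswering stores its backbone under
# `self.transformer`, but `base_model_prefix` was not updated to match, so code
# that resolves the base model via `getattr(model, model.base_model_prefix)`
# cannot find it. The workaround (see mixin_llama.py in this diff) pins the
# prefix to the attribute name that is actually used:
class LlamaForQuestionAnsweringAdapterMixin:
    base_model_prefix = "transformer"
```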

---------

Co-authored-by: calpt <calpt@mail.de>
lenglaender and calpt authored Apr 6, 2024
1 parent 93fff96 commit a9152e7
Showing 22 changed files with 107 additions and 106 deletions.
4 changes: 2 additions & 2 deletions docs/adapter_composition.md
@@ -19,7 +19,7 @@ model.active_adapters = "adapter_name"
- You cannot activate an adapter before previously adding it to the model using either ``add_adapter()`` or ``load_adapter()``.
- All adapters not mentioned in the ``active_adapters`` setup are ignored, although they might have been loaded into the model. Thus, after adding an adapter, make sure to activate it.
```
Note that we also could have used the [`set_active_adapters`](adapters.) method with `model.set_active_adapters("adapter_name")` which does the same.
Note that we also could have used the `set_active_adapters` method with `model.set_active_adapters("adapter_name")` which does the same.

Alternatively, the [`AdapterSetup`](adapters.AdapterSetup) context manager allows dynamic configuration of activated setups without changing the model state:
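
For illustration, a minimal sketch of this context manager (`model` and `inputs` are placeholders for an adapter-enabled model and tokenized inputs):

```python
from adapters import AdapterSetup

# Only "adapter_name" is active inside the `with` block; the model's persistent
# active_adapters state is left unchanged.
with AdapterSetup("adapter_name"):
    outputs = model(**inputs)
```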

@@ -125,7 +125,7 @@ model.active_adapters = ac.Fuse("d", "e", "f")

To learn how training an _AdapterFusion_ layer works, check out [this Colab notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/03_Adapter_Fusion.ipynb) from the `adapters` repo.

#### Retrieving AdapterFusion attentions
### Retrieving AdapterFusion attentions

Finally, it is possible to retrieve the attention scores computed by each fusion layer in a forward pass of the model.
These scores can be used for analyzing the fused adapter blocks and can serve as the basis for visualizations similar to those in the AdapterFusion paper.
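
As a sketch, assuming the `output_adapter_fusion_attentions` forward argument provided by the library, the scores can be requested and read like this:

```python
# Ask for fusion attention scores in the forward pass (`model` has an active
# Fuse setup, `inputs` are tokenized inputs).
outputs = model(**inputs, output_adapter_fusion_attentions=True)
attention_scores = outputs.adapter_fusion_attentions  # per-layer fusion attention scores
```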
3 changes: 3 additions & 0 deletions docs/classes/models/auto.rst
@@ -4,6 +4,9 @@ Auto Classes
Similar to the ``AutoModel`` classes built into HuggingFace Transformers, adapters provides an ``AutoAdapterModel`` class.
As with other auto classes, the correct adapter model class is automatically instantiated based on the pre-trained model passed to the ``from_pretrained()`` method.

.. note::
If the model loaded with the ``from_pretrained(...)`` function has a head, this head gets loaded as well. However, this only works for non-sharded models. If you want to load a sharded model with a head, you first need to load the model and then the head separately.
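
For example (a sketch; the checkpoint name is only illustrative), loading a non-sharded checkpoint that ships a classification head converts that head into a flexible prediction head:

```python
from adapters import AutoAdapterModel

# The checkpoint's static classification head is converted into a flexible
# prediction head on the resulting adapter model.
model = AutoAdapterModel.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
print(model.heads.keys())  # the converted head should appear here
```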

AutoAdapterModel
~~~~~~~~~~~~~~~~~~~~

2 changes: 1 addition & 1 deletion docs/classes/models/bart.rst
@@ -22,4 +22,4 @@ BartAdapterModel

.. autoclass:: adapters.BartAdapterModel
:members:
:inherited-members: BartPretrainedModel
:inherited-members: BartPreTrainedModel
2 changes: 1 addition & 1 deletion docs/classes/models/electra.rst
@@ -1,5 +1,5 @@
ELECTRA
======
=======

The ELECTRA model was proposed in the paper `ELECTRA: Pre-training Text Encoders as Discriminators Rather Than
Generators <https://openreview.net/pdf?id=r1xMH1BtvB>`__. ELECTRA is a new pretraining approach which trains two
5 changes: 5 additions & 0 deletions docs/classes/models/llama.rst
@@ -1,6 +1,11 @@
LLaMA
-----------------------------------------------------------------------------------------------------------------------

.. note::
Loading a ``LlamaForQuestionAnswering`` via [`AutoAdapterModel`](adapters.AutoAdapterModel) or via [`LlamaAdapterModel`](adapters.LlamaAdapterModel) does not load the prediction head, even if the model is not sharded. Please load the base model first and then load the head separately.
Note that for sharded models, the head is never loaded automatically, as described in [Auto Classes](auto.rst).
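
A minimal sketch of that two-step pattern (the checkpoint name and the availability of a question-answering head for `LlamaAdapterModel` are assumptions):

```python
from adapters import LlamaAdapterModel

# Step 1: load only the base model weights.
model = LlamaAdapterModel.from_pretrained("meta-llama/Llama-2-7b-hf")
# Step 2: add (or load) the question-answering head separately.
model.add_qa_head("qa")
```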


The LLaMA model was proposed in `LLaMA: Open and Efficient Foundation Language Models <https://arxiv.org/abs/2302.13971>`__ by
Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal,
Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, Guillaume Lample. It is a collection of foundation language
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -90,7 +90,7 @@

def skip_head_member(app, what, name, obj, skip, options):
if type(obj).__name__ == "function" and "inherited-members" in options and (m := re.match(r"add\_(.*)\_head$", name)):
cls_name = options["inherited-members"].replace("PreTrainedModel", "AdapterModel").replace("PretrainedModel", "AdapterModel")
cls_name = list(options["inherited-members"])[0].replace("PreTrainedModel", "AdapterModel").replace("PretrainedModel", "AdapterModel")
cls = vars(sys.modules["adapters"])[cls_name]
# HACK: currently parses head type from name
head_type_str = m.group(1).replace("qa", "question_answering")
1 change: 1 addition & 0 deletions docs/index.rst
@@ -82,6 +82,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/gptj
classes/models/llama
classes/models/mbart
classes/models/mt5
classes/models/roberta
classes/models/t5
classes/models/vit
2 changes: 1 addition & 1 deletion docs/methods.md
@@ -1,7 +1,7 @@
# Adapter Methods

On this page, we present all adapter methods currently integrated into the `adapters` library.
A tabular overview of adapter methods is provided [here](overview.html#table-of-adapter-methods).
A tabular overview of adapter methods is provided [here](overview.md#table-of-adapter-methods).
Additionally, options to combine multiple adapter methods in a single setup are presented [on the next page](method_combinations.md).

## Bottleneck Adapters
12 changes: 5 additions & 7 deletions docs/prediction_heads.md
@@ -6,7 +6,7 @@ We will take a look at the `AdapterModel` classes (e.g. `BertAdapterModel`) intr
```{eval-rst}
.. tip::
We recommend using the `AdapterModel classes <#adaptermodel-classes>`_ whenever possible.
They have been created specifically for working with adapters and provide more flexibility.
These **flexible** models have been created specifically for working with adapters.
```

## AdapterModel classes
@@ -18,16 +18,14 @@ First, we load pre-trained model from the Hugging Face Hub via the [`AutoAdapter
model = AutoAdapterModel.from_pretrained("bert-base-uncased")
```

By default, this model doesn't have any heads yet. We add a new one in the next step:
By default, this model doesn't have any heads yet, so let's add a new binary sequence classification head on top of our model:
```python
model.add_classification_head("mrpc", num_labels=2)
```
The line above adds a binary sequence classification head on top of our model.
Because this head is named, we could add multiple other heads with different names to the same model.
This is especially useful if used together with matching adapter modules.
To learn more about the different head types and the configuration options, please refer to the class references of the respective model classes, e.g. [`BertAdapterModel`](adapters.BertAdapterModel).
All heads have a name; we called this new head `"mrpc"`. Since all heads are named, we can add multiple other heads with different names to the same model.
To see the head types of a model and how they can get configured, please refer to the class references of the respective model classes, e.g. [`BertAdapterModel`](adapters.BertAdapterModel).
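
For instance (a sketch; the tagging head type is assumed to be available for this model), further named heads can sit next to the `"mrpc"` head:

```python
# Add a second, differently named head alongside the "mrpc" classification head.
model.add_tagging_head("ner", num_labels=9)
print(list(model.heads.keys()))  # e.g. ['mrpc', 'ner']
```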

Now, of course, we would like to train our classification head together with an adapter, so let's add one:
A head alone is just one layer with very few parameters. Hence, we want to train our classification head together with an adapter, so let's add one:
```python
model.add_adapter("mrpc", config="seq_bn")
model.set_active_adapters("mrpc")
3 changes: 2 additions & 1 deletion docs/quickstart.md
@@ -120,4 +120,5 @@ model.delete_adapter(adapter_name)

_We also have a Quickstart Colab notebook for adapter training:_ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb)

For more examples on training different adapter setups, refer to the section on [Adapter Training](training.md).
For more examples of training different adapter setups, refer to the section on [Adapter Training](training.md).
Further information on using adapters with prediction heads can be found in the [Prediction Heads](prediction_heads.md) section.
2 changes: 1 addition & 1 deletion docs/training.md
@@ -84,7 +84,7 @@ model.set_active_adapters(task_name)

### Step D - Switch to `AdapterTrainer` class

Finally, we exchange the `Trainer` class built into Transformers for the [`AdapterTrainer`](transformers.adapters.AdapterTrainer) class that is optimized for training adapter methods.
Finally, we exchange the `Trainer` class built into Transformers for the [`AdapterTrainer`](adapters.trainer.AdapterTrainer) class that is optimized for training adapter methods.
See [below for more information](#adaptertrainer).
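
As a sketch (assuming `model`, `training_args`, and the datasets are already prepared as in the surrounding guide), the swap looks like this:

```python
from adapters import AdapterTrainer

# Drop-in replacement for transformers.Trainer, optimized for adapter training.
trainer = AdapterTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
```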

Technically, this change is not required as no changes to the training loop are required for training adapters.
2 changes: 1 addition & 1 deletion hf_transformers
Submodule hf_transformers updated 1515 files
12 changes: 6 additions & 6 deletions setup.py
@@ -51,16 +51,16 @@
"sacremoses",
"scikit-learn",
"sentencepiece>=0.1.91,!=0.1.92",
"sphinx-copybutton",
"sphinx-markdown-tables",
"sphinx-copybutton==0.5.2",
"sphinx-markdown-tables==0.0.17",
"sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style.
"sphinx==3.2.1",
"sphinx==5.0.2",
"sphinxext-opengraph==0.4.1",
"sphinx-intl",
"sphinx-multiversion",
"sphinx-intl==2.1.0",
"sphinx-multiversion==0.2.4",
"timeout-decorator",
"torch>=1.10,!=1.12.0",
"transformers~=4.36.0",
"transformers~=4.38.1",
]


12 changes: 12 additions & 0 deletions src/adapters/head_utils.py
@@ -498,6 +498,7 @@
},
"layers": [None, "qa_outputs"],
},
# T5
"T5ForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
@@ -526,6 +527,7 @@
"classification_head.out_proj",
],
},
# DeBERTaV2
"DebertaV2ForSequenceClassification": {
"config": {
"head_type": "classification",
@@ -575,6 +577,7 @@
},
"layers": [None, "pooler.dense", None, None, "classifier"],
},
# DeBERTa
"DebertaForSequenceClassification": {
"config": {
"head_type": "classification",
@@ -641,6 +644,15 @@
},
"layers": ["lm_head"],
},
"LlamaForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"layers": 1,
"activation_function": None,
},
"layers": [None, "qa_outputs"],
},
# Electra
"ElectraForTokenClassification": {
"config": {
"head_type": "tagging",
3 changes: 2 additions & 1 deletion src/adapters/hub_mixin.py
@@ -70,7 +70,8 @@ def _save_adapter_card(
metrics: Optional[List[str]] = None,
**kwargs
):
all_tags = {"adapter-transformers"} # TODO: change this tag once changed on HF side
# Key remains "adapter-transformers", see: https://github.com/huggingface/huggingface.js/pull/459
all_tags = {"adapter-transformers"}
datasets = set()
# Dataset/ Task info
dataset_name = None
3 changes: 2 additions & 1 deletion src/adapters/models/__init__.py
@@ -17,7 +17,7 @@
from .distilbert.mixin_distilbert import DistilBertModelAdaptersMixin, DistilBertTransformerAdaptersMixin
from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin
from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin
from .llama.mixin_llama import LlamaModelAdapterMixin
from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin
from .t5.mixin_t5 import (
T5BlockAdaptersMixin,
T5ForCondiditionalGenerationWithHeadsMixin,
@@ -83,4 +83,5 @@
"BertGenerationEncoder": BertModelAdaptersMixin,
"BertGenerationLayer": BertLayerAdaptersMixin,
"LlamaModel": LlamaModelAdapterMixin,
"LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin,
}
4 changes: 2 additions & 2 deletions src/adapters/models/bart/adapter_model.py
@@ -5,7 +5,7 @@
BART_START_DOCSTRING,
BartConfig,
BartModel,
BartPretrainedModel,
BartPreTrainedModel,
shift_tokens_right,
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward
@@ -18,7 +18,7 @@
@add_start_docstrings(
"BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING
)
class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel):
class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPreTrainedModel):
_tied_weights_keys = [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
3 changes: 3 additions & 0 deletions src/adapters/models/llama/adapter_model.py
@@ -1,4 +1,5 @@
import logging
from typing import Optional

import torch

@@ -58,6 +59,7 @@ def forward(
past_key_values=None,
inputs_embeds=None,
use_cache=None,
cache_position: Optional[torch.LongTensor] = None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -79,6 +81,7 @@
position_ids=position_ids,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
output_attentions=output_attentions,
return_dict=return_dict,
output_hidden_states=output_hidden_states,
6 changes: 6 additions & 0 deletions src/adapters/models/llama/mixin_llama.py
@@ -44,3 +44,9 @@ def post_embedding_forward(self, module, args, embedding_output):
embedding_output = self.invertible_adapters_forward(embedding_output)
# Prompt tuning not yet supported
return embedding_output


class LlamaForQuestionAnsweringAdapterMixin:
# this is needed because Transformers v4.38.1 is inconsistent with the naming of the base model but didn't change the base_model_prefix
# TODO: remove this when the inconsistency is fixed and remove the LlamaForQuestionAnsweringAdapterMixin from `src/adapters/models/__init__.py`
base_model_prefix = "transformer"