Add adapters to Vision Transformer #363

Merged: 11 commits, merged on Jun 29, 2022

2 changes: 1 addition & 1 deletion .github/workflows/tests_torch.yml
@@ -55,7 +55,7 @@ jobs:
- name: Install
run: |
pip install torch
pip install .[sklearn,testing,sentencepiece]
pip install .[sklearn,testing,sentencepiece,vision]
pip install datasets
- name: Test
run: |
27 changes: 27 additions & 0 deletions adapter_docs/classes/models/vit.rst
@@ -0,0 +1,27 @@
Vision Transformer (ViT)
=========================

The Vision Transformer (ViT) model was proposed in `An Image is Worth 16x16 Words: Transformers for Image Recognition
at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk
Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob
Uszkoreit, Neil Houlsby. It is the first paper to successfully train a Transformer encoder on ImageNet, attaining
very good results compared to familiar convolutional architectures.


The abstract from the paper is the following:

*While the Transformer architecture has become the de-facto standard for natural language processing tasks, its
applications to computer vision remain limited. In vision, attention is either applied in conjunction with
convolutional networks, or used to replace certain components of convolutional networks while keeping their overall
structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to
sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of
data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.),
Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring
substantially fewer computational resources to train.*

ViTAdapterModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.adapters.ViTAdapterModel
:members:
:inherited-members: ViTPreTrainedModel
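
For context (not part of the diff): a minimal usage sketch of the new ViT adapter support. The checkpoint and adapter/head names below are illustrative.

```python
from transformers import AutoAdapterModel

# Load a pre-trained ViT backbone with flexible adapter heads
model = AutoAdapterModel.from_pretrained("google/vit-base-patch16-224-in21k")

# Add a bottleneck adapter plus a matching image classification head
model.add_adapter("cifar10")
model.add_image_classification_head("cifar10", num_labels=10)

# Train only the adapter (and head) weights; the ViT backbone stays frozen
model.train_adapter("cifar10")
model.set_active_adapters("cifar10")
```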
1 change: 1 addition & 0 deletions adapter_docs/index.rst
@@ -60,6 +60,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/mbart
classes/models/roberta
classes/models/t5
classes/models/vit
classes/models/xlmroberta

.. toctree::
1 change: 1 addition & 0 deletions adapter_docs/model_overview.md
@@ -22,6 +22,7 @@ The table below further shows which model architectures support which adaptation
| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |

(*) If the used encoder and decoder model class are supported.
2 changes: 1 addition & 1 deletion adding_adapters_to_a_model.md
@@ -24,7 +24,7 @@ For this purpose, there typically exists a module `src/transformers/adapters/mix

- Add a new `<model_type>.py` module for your architecture in `src/transformers/adapters/mixins` (or reuse an existing if possible).
- There usually exists a mixin on the Transformer layer level that holds the modules for adapter layers.
- The mixin for the whole base model class (e.g. `BertModel`) should derive from `ModelAdaptersMixin` and (if possible) `InvertibleAdaptersMixin`. This mixin should at least implement the `iter_layers()` method but might require additional modifications depending on the architecture.
- The mixin for the whole base model class (e.g. `BertModel`) should derive from `ModelAdaptersMixin` and (if possible) `EmbeddingAdaptersMixin` and/or `InvertibleAdaptersMixin`. This mixin should at least implement the `iter_layers()` method but might require additional modifications depending on the architecture.
- Have a look at existing examples, e.g. `distilbert.py`, `bert.py`.
- Implement the mixins and the required modifications on the modeling classes (`modeling_<model_type>.py`).
- Make sure the calls to `adapter_layer_forward()` are added in the right places.
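
To illustrate these steps for ViT (a sketch, not the actual PR code): ViT operates on image patches rather than token embeddings, so the model-level mixin would derive only from `ModelAdaptersMixin`, and `iter_layers()` is assumed to walk `self.encoder.layer` as in the Hugging Face ViT implementation. The module sketch mirrors the structure of `mixins/bart.py` shown further down in this diff.

```python
from typing import Iterable, Tuple

import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import ModelAdaptersMixin


class ViTLayerAdaptersMixin:
    """Sketch: adds adapter modules to a single ViT Transformer layer."""

    def _init_adapter_modules(self):
        self.attention_adapters = AdapterLayer("mh_adapter", self.config)
        self.output_adapters = AdapterLayer("output_adapter", self.config)
        self.attention_adapters._init_adapter_modules()
        self.output_adapters._init_adapter_modules()


class ViTModelAdaptersMixin(ModelAdaptersMixin):
    """Sketch: adds adapters to the ViTModel base class."""

    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
        # Assumes the Transformer blocks live at `self.encoder.layer`
        for i, layer in enumerate(self.encoder.layer):
            yield i, layer
```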
3 changes: 2 additions & 1 deletion examples/pytorch/README.md
@@ -36,7 +36,7 @@ Currently, scripts for these tasks support adapters:

| Task | Description |
| --- | --- |
| [**`language-modeling`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/language-modeling) | Causal & Masked language modeling
| [**`language-modeling`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/pytorch/language-modeling) | Causal & Masked language modeling
| [**`multiple-choice`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/pytorch/multiple-choice) | SWAG Dataset
| [**`question-answering`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/pytorch/question-answering) | SQuAD-style QA
| [**`summarization`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/pytorch/seq2seq) | Summarization, e.g. on CNN/Dailymail or XSum
@@ -45,6 +45,7 @@ Currently, scripts for these tasks support adapters:
| [**`token-classification`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/pytorch/token-classification) | NER, e.g. on CoNLL2003
| [**`translation`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/pytorch/seq2seq) | Machine translation, e.g. on WMT tasks
| [**`dependency-parsing`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/pytorch/dependency-parsing) | Dependency parsing on Universal Dependencies
| [**`image-classification`**](https://github.com/adapter-hub/adapter-transformers/tree/master/examples/pytorch/image-classification) | Image classification, e.g. on CIFAR-10/-100

All scripts listed above which can be used for training provide a new `--train_adapter` option that switches between full fine-tuning and adapter training.
Loading pre-trained adapters can be done via `--load_adapter`.
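
Both switches are consumed through the `AdapterArguments` dataclass; a hedged sketch of how a script picks them up (mirroring the `run_image_classification.py` changes below):

```python
from transformers import AdapterArguments, HfArgumentParser, TrainingArguments

# Parse the standard training arguments plus the adapter-specific ones
parser = HfArgumentParser((TrainingArguments, AdapterArguments))
training_args, adapter_args = parser.parse_args_into_dataclasses()

print(adapter_args.train_adapter)   # --train_adapter
print(adapter_args.load_adapter)    # --load_adapter
print(adapter_args.adapter_config)  # e.g. "pfeiffer" or "houlsby"
```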
56 changes: 50 additions & 6 deletions examples/pytorch/image-classification/run_image_classification.py
@@ -37,9 +37,12 @@
import transformers
from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
AdapterArguments,
AdapterConfig,
AdapterTrainer,
AutoAdapterModel,
AutoConfig,
AutoFeatureExtractor,
AutoModelForImageClassification,
HfArgumentParser,
Trainer,
TrainingArguments,
@@ -157,13 +160,15 @@ def main():
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, AdapterArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
model_args, data_args, training_args, adapter_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args, data_args, training_args, adapter_args = parser.parse_args_into_dataclasses()

# Setup logging
logging.basicConfig(
@@ -256,21 +261,59 @@ def compute_metrics(p):
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForImageClassification.from_pretrained(
model = AutoAdapterModel.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
model.add_image_classification_head(
data_args.dataset_name or "img_clf",
num_labels=len(labels),
id2label=id2label,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)

# Setup adapters
if adapter_args.train_adapter:
task_name = data_args.dataset_name or "img_clf"
# check if adapter already exists, otherwise add it
if task_name not in model.config.adapters:
# resolve the adapter config
adapter_config = AdapterConfig.load(
adapter_args.adapter_config,
non_linearity=adapter_args.adapter_non_linearity,
reduction_factor=adapter_args.adapter_reduction_factor,
)
# load a pre-trained adapter from the Hub if specified
if adapter_args.load_adapter:
model.load_adapter(
adapter_args.load_adapter,
config=adapter_config,
load_as=task_name,
)
# otherwise, add a fresh adapter
else:
model.add_adapter(task_name, config=adapter_config)

# Freeze all model weights except those of this adapter
model.train_adapter([task_name])
# Set the adapters to be used in every forward pass
model.set_active_adapters([task_name])
else:
if adapter_args.load_adapter:
raise ValueError(
"Adapters can only be loaded in adapters training mode."
"Use --train_adapter to enable adapter training"
)

# Define torchvision transforms to be applied to each image.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_train_transforms = Compose(
Expand Down Expand Up @@ -323,7 +366,8 @@ def val_transforms(example_batch):
dataset["validation"].set_transform(val_transforms)

# Initialize our trainer
trainer = Trainer(
trainer_class = AdapterTrainer if adapter_args.train_adapter else Trainer
trainer = trainer_class(
model=model,
args=training_args,
train_dataset=dataset["train"] if training_args.do_train else None,
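
One thing not visible in the diff: after `model.train_adapter([task_name])`, only the adapter and head parameters should remain trainable. A quick, illustrative sanity check:

```python
# Illustrative: list which parameter tensors are still trainable after train_adapter()
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print(f"{len(trainable)} trainable parameter tensors, e.g. {trainable[:3]}")
```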
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -1821,6 +1821,7 @@
"DistilBertAdapterModel",
"DistilBertModelWithHeads",
"DynamicAdapterFusionConfig",
"EmbeddingAdaptersMixin",
"ForwardContext",
"GPT2AdapterModel",
"GPT2ModelWithHeads",
@@ -1846,6 +1847,7 @@
"StaticAdapterFusionConfig",
"T5AdapterModel",
"T5ModelWithHeads",
"ViTAdapterModel",
"XLMRobertaAdapterModel",
"XLMRobertaModelWithHeads",
"get_adapter_config_hash",
@@ -4115,6 +4117,7 @@
DistilBertAdapterModel,
DistilBertModelWithHeads,
DynamicAdapterFusionConfig,
EmbeddingAdaptersMixin,
ForwardContext,
GPT2AdapterModel,
GPT2ModelWithHeads,
@@ -4140,6 +4143,7 @@
StaticAdapterFusionConfig,
T5AdapterModel,
T5ModelWithHeads,
ViTAdapterModel,
XLMRobertaAdapterModel,
XLMRobertaModelWithHeads,
get_adapter_config_hash,
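
With these entries in place, the new symbols resolve from the top-level package via the usual lazy-import machinery:

```python
# Both names are exported by this PR
from transformers import EmbeddingAdaptersMixin, ViTAdapterModel
```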
10 changes: 9 additions & 1 deletion src/transformers/adapters/__init__.py
@@ -78,6 +78,7 @@
],
"layer": ["AdapterLayer", "AdapterLayerBase"],
"model_mixin": [
"EmbeddingAdaptersMixin",
"InvertibleAdaptersMixin",
"ModelAdaptersMixin",
"ModelWithHeadsAdaptersMixin",
@@ -118,6 +119,7 @@
"T5AdapterModel",
"T5ModelWithHeads",
],
"models.vit": ["ViTAdapterModel"],
"models.xlm_roberta": [
"XLMRobertaAdapterModel",
"XLMRobertaModelWithHeads",
@@ -189,7 +191,12 @@
TaggingHead,
)
from .layer import AdapterLayer, AdapterLayerBase
from .model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin, ModelWithHeadsAdaptersMixin
from .model_mixin import (
EmbeddingAdaptersMixin,
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
from .models.auto import ADAPTER_MODEL_MAPPING, MODEL_WITH_HEADS_MAPPING, AutoAdapterModel, AutoModelWithHeads
from .models.bart import BartAdapterModel, BartModelWithHeads
from .models.bert import BertAdapterModel, BertModelWithHeads
@@ -200,6 +207,7 @@
from .models.mbart import MBartAdapterModel, MBartModelWithHeads
from .models.roberta import RobertaAdapterModel, RobertaModelWithHeads
from .models.t5 import T5AdapterModel, T5ModelWithHeads
from .models.vit import ViTAdapterModel
from .models.xlm_roberta import XLMRobertaAdapterModel, XLMRobertaModelWithHeads
from .trainer import AdapterTrainer, Seq2SeqAdapterTrainer
from .training import AdapterArguments, MultiLingAdapterArguments
14 changes: 13 additions & 1 deletion src/transformers/adapters/composition.py
@@ -99,7 +99,19 @@ def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], b
# Some composition blocks might not be supported by all models.
# Add a whitelist of models for those here.
SUPPORTED_MODELS = {
Parallel: ["bert", "roberta", "distilbert", "deberta-v2", "deberta", "bart", "mbart", "gpt2", "t5", "xlm-roberta"],
Parallel: [
"bert",
"roberta",
"distilbert",
"deberta-v2",
"deberta",
"bart",
"mbart",
"gpt2",
"t5",
"vit",
"xlm-roberta",
],
}


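
Since `"vit"` joins the `Parallel` whitelist, parallel multi-adapter inference should also work for ViT models; a hedged sketch (adapter names illustrative):

```python
from transformers import AutoAdapterModel
from transformers.adapters.composition import Parallel

model = AutoAdapterModel.from_pretrained("google/vit-base-patch16-224-in21k")
model.add_adapter("cifar10")
model.add_adapter("cifar100")

# Run both adapters in parallel within a single forward pass
model.set_active_adapters(Parallel("cifar10", "cifar100"))
```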
9 changes: 9 additions & 0 deletions src/transformers/adapters/head_utils.py
@@ -392,6 +392,15 @@
"cls.predictions.decoder",
],
},
# ViT
"ViTForImageClassification": {
"config": {
"head_type": "image_classification",
"layers": 1,
"activation_function": None,
},
"layers": {"classifier"},
},
}


69 changes: 66 additions & 3 deletions src/transformers/adapters/heads/base.py
@@ -7,6 +7,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_outputs import (
ImageClassifierOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
Seq2SeqModelOutput,
@@ -412,6 +413,67 @@ def get_label_names(self):
return ["start_positions", "end_positions"]


class ImageClassificationHead(PredictionHead):
def __init__(
self,
model,
head_name,
num_labels=2,
layers=2,
activation_function="tanh",
multilabel=False,
id2label=None,
use_pooler=False,
bias=True,
):
super().__init__(head_name)
self.config = {
"head_type": "image_classification",
"num_labels": num_labels,
"layers": layers,
"activation_function": activation_function,
"multilabel": multilabel,
"label2id": {label: id_ for id_, label in id2label.items()} if id2label is not None else None,
"use_pooler": use_pooler,
"bias": bias,
}
self.build(model)

def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=False, **kwargs):
if cls_output is None:
if self.config["use_pooler"]:
cls_output = kwargs.pop("pooled_output")
else:
cls_output = outputs[0][:, 0]
logits = super().forward(cls_output)
loss = None
labels = kwargs.pop("labels", None)
if labels is not None:
if self.config["num_labels"] == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
elif self.config["multilabel"]:
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config["num_labels"]), labels.view(-1))

if return_dict:
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
outputs = (logits,) + outputs[1:]
if labels is not None:
outputs = (loss,) + outputs
return outputs


class ModelWithFlexibleHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
"""
Adds flexible prediction heads to a model class. Implemented by the XModelWithHeads classes.
Expand Down Expand Up @@ -693,9 +755,10 @@ def _get_head_input(outputs, cls_out, batch):
return inputs, cls_input

# Pass invertible adapter if we have one
inv_adapter = self.base_model.get_invertible_adapter()
if inv_adapter:
kwargs["invertible_adapter"] = inv_adapter
if hasattr(self.base_model, "get_invertible_adapter"):
inv_adapter = self.base_model.get_invertible_adapter()
if inv_adapter:
kwargs["invertible_adapter"] = inv_adapter

for head in used_heads:
if head not in self.heads:
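
For reference, `ImageClassificationHead` selects its loss from the head config: MSE when `num_labels == 1` (regression), `BCEWithLogitsLoss` when `multilabel=True`, and `CrossEntropyLoss` otherwise. A hedged sketch of configuring a multi-label head (assuming the convenience method forwards these keyword arguments to the head):

```python
from transformers import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("google/vit-base-patch16-224-in21k")
model.add_image_classification_head(
    "attributes",     # head name (illustrative)
    num_labels=5,     # five independent binary attributes
    multilabel=True,  # selects BCEWithLogitsLoss instead of CrossEntropyLoss
)
```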
4 changes: 2 additions & 2 deletions src/transformers/adapters/mixins/bart.py
@@ -3,7 +3,7 @@
import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin
from ..model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin


class BartEncoderLayerAdaptersMixin:
@@ -25,7 +25,7 @@ def _init_adapter_modules(self):
self.cross_attention_adapters._init_adapter_modules()


class BartModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin):
class BartModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin):
"""Adds adapters to the BartModel class."""

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: