Add adapter support to BEiT (#428)
jannik-brinkmann authored Oct 13, 2022
1 parent 2b6771f commit c6a06df
Showing 14 changed files with 352 additions and 19 deletions.
27 changes: 27 additions & 0 deletions adapter_docs/classes/models/beit.rst
@@ -0,0 +1,27 @@
Bidirectional Encoder representation from Image Transformers (BEiT)
====================================================================

The Bidirectional Encoder representation from Image Transformers (BEiT) model was proposed in `BEiT: BERT Pre-Training
of Image Transformers <https://arxiv.org/abs/2106.08254>`__ by Hangbo Bao, Li Dong, Songhao Piao and Furu Wei.


The abstract from the paper is the following:

*We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation
from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image
modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e., image
patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into
visual tokens. Then we randomly mask some image patches and feed them into the backbone Transformer. The pre-training
objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we
directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder.
Experimental results on image classification and semantic segmentation show that our model achieves competitive results
with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K,
significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains
86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%).*

BeitAdapterModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.adapters.BeitAdapterModel
    :members:
    :inherited-members: BeitPreTrainedModel
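
A minimal usage sketch (the checkpoint ``microsoft/beit-base-patch16-224`` and the adapter/head name ``beit_demo`` are
illustrative placeholders, not part of this commit):

.. code-block:: python

    from transformers.adapters import BeitAdapterModel

    model = BeitAdapterModel.from_pretrained("microsoft/beit-base-patch16-224")
    model.add_adapter("beit_demo")
    model.add_image_classification_head("beit_demo", num_labels=10)
    model.set_active_adapters("beit_demo")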
1 change: 1 addition & 0 deletions adapter_docs/index.rst
@@ -51,6 +51,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
model_overview
classes/models/auto
classes/models/bart
classes/models/beit
classes/models/bert
classes/models/deberta
classes/models/deberta_v2
1 change: 1 addition & 0 deletions adapter_docs/model_overview.md
@@ -13,6 +13,7 @@ The table below further shows which model architectures support which adaptation
| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block |
| --------------------------------------- | -| - | - | - | - | - | - |
| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BEiT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ |  |  |
| [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
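
For context, a short sketch of activating some of the methods listed above for BEiT (the adapter names and the
checkpoint are illustrative; the config classes are those exported by `transformers.adapters` at the time of this commit):

```python
from transformers.adapters import BeitAdapterModel, CompacterConfig, LoRAConfig, PrefixTuningConfig

model = BeitAdapterModel.from_pretrained("microsoft/beit-base-patch16-224")
model.add_adapter("bottleneck")                           # (bottleneck) adapter with the default config
model.add_adapter("prefix", config=PrefixTuningConfig())  # prefix tuning
model.add_adapter("lora", config=LoRAConfig())            # LoRA
model.add_adapter("compacter", config=CompacterConfig())  # Compacter
model.train_adapter("bottleneck")                         # freeze the backbone, train only this adapter
```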
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -2025,6 +2025,7 @@
"AutoModelWithHeads",
"BartAdapterModel",
"BartModelWithHeads",
"BeitAdapterModel",
"BertAdapterModel",
"BertModelWithHeads",
"CompacterConfig",
@@ -4564,6 +4565,7 @@
        AutoModelWithHeads,
        BartAdapterModel,
        BartModelWithHeads,
        BeitAdapterModel,
        BertAdapterModel,
        BertModelWithHeads,
        CompacterConfig,
2 changes: 2 additions & 0 deletions src/transformers/adapters/__init__.py
@@ -95,6 +95,7 @@
"BartAdapterModel",
"BartModelWithHeads",
],
"models.beit": ["BeitAdapterModel"],
"models.bert": [
"BertAdapterModel",
"BertModelWithHeads",
@@ -203,6 +204,7 @@
    )
    from .models.auto import ADAPTER_MODEL_MAPPING, MODEL_WITH_HEADS_MAPPING, AutoAdapterModel, AutoModelWithHeads
    from .models.bart import BartAdapterModel, BartModelWithHeads
    from .models.beit import BeitAdapterModel
    from .models.bert import BertAdapterModel, BertModelWithHeads
    from .models.deberta import DebertaAdapterModel
    from .models.debertaV2 import DebertaV2AdapterModel
10 changes: 10 additions & 0 deletions src/transformers/adapters/head_utils.py
@@ -9,6 +9,16 @@
# The "layers" attributes in the configs below map from static head module names to flex head module names.
# In this context, "None" refers to a flex-head layer without weights (e.g. dropout, acts).
STATIC_TO_FLEX_HEAD_MAP = {
    # BEIT
    "BeitForImageClassification": {
        "config": {
            "head_type": "image_classification",
            "layers": 1,
            "activation_function": None,
            "use_pooler": True,
        },
        "layers": {"classifier"},
    },
    # BERT
    "BertForSequenceClassification": {
        "config": {
38 changes: 38 additions & 0 deletions src/transformers/adapters/mixins/beit.py
@@ -0,0 +1,38 @@
import logging
from typing import Iterable, Tuple

import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import ModelAdaptersMixin, ModelWithHeadsAdaptersMixin


logger = logging.getLogger(__name__)


class BeitLayerAdaptersMixin:
"""Adds adapters to the BeitLayer module."""

def _init_adapter_modules(self):
self.attention_adapters = AdapterLayer("mh_adapter", self.config)
self.attention_adapters._init_adapter_modules()


class BeitOutputAdaptersMixin:
"""Adds adapters to the BeitOutput module."""

def _init_adapter_modules(self):
self.output_adapters = AdapterLayer("output_adapter", self.config)
self.output_adapters._init_adapter_modules()


class BeitModelAdaptersMixin(ModelAdaptersMixin):
"""Adds adapters to the BeitModel module."""

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
for i, layer in enumerate(self.encoder.layer):
yield i, layer


class BeitModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin):
    pass
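
For context, a rough sketch of where the modules created by these mixins end up once the adapter-enabled BEiT classes
are instantiated (attribute paths assume the standard `BeitModel` layout, i.e. `encoder.layer[i]` / `layer.output`; this
is an illustration, not part of the diff):

```python
from transformers.adapters import BeitAdapterModel

model = BeitAdapterModel.from_pretrained("microsoft/beit-base-patch16-224")
model.add_adapter("probe")

# BeitLayerAdaptersMixin attaches an AdapterLayer ("mh_adapter") to every BeitLayer;
# BeitOutputAdaptersMixin attaches one ("output_adapter") to every BeitOutput module.
for i, layer in enumerate(model.beit.encoder.layer):
    print(i, type(layer.attention_adapters).__name__, type(layer.output.output_adapters).__name__)
```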
1 change: 1 addition & 0 deletions src/transformers/adapters/models/auto/adapter_model.py
@@ -10,6 +10,7 @@
    [
        ("xlm-roberta", "XLMRobertaAdapterModel"),
        ("roberta", "RobertaAdapterModel"),
        ("beit", "BeitAdapterModel"),
        ("bert", "BertAdapterModel"),
        ("distilbert", "DistilBertAdapterModel"),
        ("deberta-v2", "DebertaV2AdapterModel"),
39 changes: 39 additions & 0 deletions src/transformers/adapters/models/beit/__init__.py
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ....utils import _LazyModule


_import_structure = {
"adapter_model": ["BeitAdapterModel"],
}


if TYPE_CHECKING:
    from .adapter_model import BeitAdapterModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
    )
111 changes: 111 additions & 0 deletions src/transformers/adapters/models/beit/adapter_model.py
@@ -0,0 +1,111 @@
from typing import Optional

import torch

from ....models.beit.modeling_beit import BEIT_INPUTS_DOCSTRING, BEIT_START_DOCSTRING, BeitModel, BeitPreTrainedModel
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...context import AdapterSetup
from ...heads import ImageClassificationHead, ModelWithFlexibleHeadsAdaptersMixin


@add_start_docstrings(
"""Beit Model transformer with the option to add multiple flexible heads on top.""",
BEIT_START_DOCSTRING,
)
class BeitAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BeitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.beit = BeitModel(config)

        self._init_head_modules()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        head=None,
        output_adapter_gating_scores=False,
        output_adapter_fusion_attentions=False,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.beit(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            output_adapter_gating_scores=output_adapter_gating_scores,
            output_adapter_fusion_attentions=output_adapter_fusion_attentions,
        )

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
else:
head_inputs = outputs
pooled_output = outputs[1]

        if head or AdapterSetup.get_context_head_setup() or self.active_head:
            head_outputs = self.forward_head(
                head_inputs,
                cls_output=pooled_output,  # BEiT does classification based on average-pooling of last hidden state
                head_name=head,
                return_dict=return_dict,
                pooled_output=pooled_output,
                **kwargs,
            )
            return head_outputs
        else:
            # in case no head is used just return the output of the base model (including pooler output)
            return outputs

    head_types = {
        "image_classification": ImageClassificationHead,
    }

    def add_image_classification_head(
        self,
        head_name,
        num_labels=2,
        layers=1,
        activation_function="tanh",
        overwrite_ok=False,
        multilabel=False,
        id2label=None,
        use_pooler=True,
    ):
"""
Adds an image classification head on top of the model.
Args:
head_name (str): The name of the head.
num_labels (int, optional): Number of classification labels. Defaults to 2.
layers (int, optional): Number of layers. Defaults to 1.
activation_function (str, optional): Activation function. Defaults to 'tanh'.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
multilabel (bool, optional): Enable multilabel classification setup. Defaults to False.
"""

        head = ImageClassificationHead(
            self,
            head_name,
            num_labels=num_labels,
            layers=layers,
            activation_function=activation_function,
            multilabel=multilabel,
            id2label=id2label,
            use_pooler=use_pooler,
        )
        self.add_prediction_head(head, overwrite_ok)
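
For context, a sketch of how the flexible head added by `add_image_classification_head` could be used end to end (the
checkpoint and the adapter/head name are illustrative assumptions, not part of this commit):

```python
import torch

from transformers.adapters import BeitAdapterModel

model = BeitAdapterModel.from_pretrained("microsoft/beit-base-patch16-224")
model.add_adapter("cifar10")                                   # bottleneck adapter named "cifar10"
model.add_image_classification_head("cifar10", num_labels=10)  # flex head defined in this file
model.set_active_adapters("cifar10")

# Dummy input; in practice pixel_values come from the BEiT feature extractor.
pixel_values = torch.randn(1, 3, 224, 224)
outputs = model(pixel_values, head="cifar10")                  # `head=` selects among the registered flexible heads
print(outputs.logits.shape)                                    # expected: torch.Size([1, 10])
```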
1 change: 1 addition & 0 deletions src/transformers/adapters/wrappers/configuration.py
@@ -10,6 +10,7 @@
"hidden_dropout_prob": "dropout",
"attention_probs_dropout_prob": "attention_dropout",
},
"beit": {},
"bert": {},
"distilbert": {
"hidden_dropout_prob": "dropout",