From c6a06df78ad3e5c445670eac4585cad39b2e51a4 Mon Sep 17 00:00:00 2001 From: Jannik Brinkmann <62884101+jannik-brinkmann@users.noreply.github.com> Date: Thu, 13 Oct 2022 14:07:34 +0200 Subject: [PATCH] Add adapter support to BEiT (#428) --- adapter_docs/classes/models/beit.rst | 27 +++++ adapter_docs/index.rst | 1 + adapter_docs/model_overview.md | 1 + src/transformers/__init__.py | 2 + src/transformers/adapters/__init__.py | 2 + src/transformers/adapters/head_utils.py | 10 ++ src/transformers/adapters/mixins/beit.py | 38 ++++++ .../adapters/models/auto/adapter_model.py | 1 + .../adapters/models/beit/__init__.py | 39 ++++++ .../adapters/models/beit/adapter_model.py | 111 ++++++++++++++++++ .../adapters/wrappers/configuration.py | 1 + src/transformers/models/beit/modeling_beit.py | 68 ++++++++--- tests_adapters/test_beit.py | 69 +++++++++++ utils/check_adapters.py | 1 + 14 files changed, 352 insertions(+), 19 deletions(-) create mode 100644 adapter_docs/classes/models/beit.rst create mode 100644 src/transformers/adapters/mixins/beit.py create mode 100644 src/transformers/adapters/models/beit/__init__.py create mode 100644 src/transformers/adapters/models/beit/adapter_model.py create mode 100644 tests_adapters/test_beit.py diff --git a/adapter_docs/classes/models/beit.rst b/adapter_docs/classes/models/beit.rst new file mode 100644 index 0000000000..fec1247a42 --- /dev/null +++ b/adapter_docs/classes/models/beit.rst @@ -0,0 +1,27 @@ +Bidirectional Encoder representation from Image Transformers (BEiT) +========================= + +The Bidirectional Encoder representation from Image Transformers (BEiT) model was proposed in `BERT Pre-Training of Image +Transformers `__ by Hangbo Bao, Li Dong, Songhao Piao, Furu Wei. + + +The abstract from the paper is the following: + +*We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation +from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image +modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e, image +patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into +visual tokens. Then we randomly mask some image patches and fed them into the backbone Transformer. The pre-training +objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we +directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. +Experimental results on image classification and semantic segmentation show that our model achieves competitive results +with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, +significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains +86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%).* + +BeitAdapterModel +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.adapters.BeitAdapterModel + :members: + :inherited-members: BeitPreTrainedModel diff --git a/adapter_docs/index.rst b/adapter_docs/index.rst index d4729ed9f4..d4b0b30889 100644 --- a/adapter_docs/index.rst +++ b/adapter_docs/index.rst @@ -51,6 +51,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model model_overview classes/models/auto classes/models/bart + classes/models/beit classes/models/bert classes/models/deberta classes/models/deberta_v2 diff --git a/adapter_docs/model_overview.md b/adapter_docs/model_overview.md index 9d5eb154cc..f566c7e3e9 100644 --- a/adapter_docs/model_overview.md +++ b/adapter_docs/model_overview.md @@ -13,6 +13,7 @@ The table below further shows which model architectures support which adaptation | Model | (Bottleneck)
Adapters | Prefix
Tuning | LoRA | Compacter | Adapter
Fusion | Invertible
Adapters | Parallel
block | | --------------------------------------- | -| - | - | - | - | - | - | | [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | [BERT](classes/models/bert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [DeBERTa](classes/models/deberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [DeBERTa-v2](classes/models/debertaV2.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5ccbc29a36..f28c722ba5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2025,6 +2025,7 @@ "AutoModelWithHeads", "BartAdapterModel", "BartModelWithHeads", + "BeitAdapterModel", "BertAdapterModel", "BertModelWithHeads", "CompacterConfig", @@ -4564,6 +4565,7 @@ AutoModelWithHeads, BartAdapterModel, BartModelWithHeads, + BeitAdapterModel, BertAdapterModel, BertModelWithHeads, CompacterConfig, diff --git a/src/transformers/adapters/__init__.py b/src/transformers/adapters/__init__.py index 9feed61353..9df65b7ff3 100644 --- a/src/transformers/adapters/__init__.py +++ b/src/transformers/adapters/__init__.py @@ -95,6 +95,7 @@ "BartAdapterModel", "BartModelWithHeads", ], + "models.beit": ["BeitAdapterModel"], "models.bert": [ "BertAdapterModel", "BertModelWithHeads", @@ -203,6 +204,7 @@ ) from .models.auto import ADAPTER_MODEL_MAPPING, MODEL_WITH_HEADS_MAPPING, AutoAdapterModel, AutoModelWithHeads from .models.bart import BartAdapterModel, BartModelWithHeads + from .models.beit import BeitAdapterModel from .models.bert import BertAdapterModel, BertModelWithHeads from .models.deberta import DebertaAdapterModel from .models.debertaV2 import DebertaV2AdapterModel diff --git a/src/transformers/adapters/head_utils.py b/src/transformers/adapters/head_utils.py index adcaf62304..23723f07b9 100644 --- a/src/transformers/adapters/head_utils.py +++ b/src/transformers/adapters/head_utils.py @@ -9,6 +9,16 @@ # The "layers" attributes in the configs below map from static head module names to flex head module names. # In this context, "None" refers to a flex-head layer without weights (e.g. dropout, acts). STATIC_TO_FLEX_HEAD_MAP = { + # BEIT + "BeitForImageClassification": { + "config": { + "head_type": "image_classification", + "layers": 1, + "activation_function": None, + "use_pooler": True, + }, + "layers": {"classifier"}, + }, # BERT "BertForSequenceClassification": { "config": { diff --git a/src/transformers/adapters/mixins/beit.py b/src/transformers/adapters/mixins/beit.py new file mode 100644 index 0000000000..d045264047 --- /dev/null +++ b/src/transformers/adapters/mixins/beit.py @@ -0,0 +1,38 @@ +import logging +from typing import Iterable, Tuple + +import torch.nn as nn + +from ..layer import AdapterLayer +from ..model_mixin import ModelAdaptersMixin, ModelWithHeadsAdaptersMixin + + +logger = logging.getLogger(__name__) + + +class BeitLayerAdaptersMixin: + """Adds adapters to the BeitLayer module.""" + + def _init_adapter_modules(self): + self.attention_adapters = AdapterLayer("mh_adapter", self.config) + self.attention_adapters._init_adapter_modules() + + +class BeitOutputAdaptersMixin: + """Adds adapters to the BeitOutput module.""" + + def _init_adapter_modules(self): + self.output_adapters = AdapterLayer("output_adapter", self.config) + self.output_adapters._init_adapter_modules() + + +class BeitModelAdaptersMixin(ModelAdaptersMixin): + """Adds adapters to the BeitModel module.""" + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(self.encoder.layer): + yield i, layer + + +class BeitModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin): + pass diff --git a/src/transformers/adapters/models/auto/adapter_model.py b/src/transformers/adapters/models/auto/adapter_model.py index 7d403f72cd..3fb56a3027 100644 --- a/src/transformers/adapters/models/auto/adapter_model.py +++ b/src/transformers/adapters/models/auto/adapter_model.py @@ -10,6 +10,7 @@ [ ("xlm-roberta", "XLMRobertaAdapterModel"), ("roberta", "RobertaAdapterModel"), + ("beit", "BeitAdapterModel"), ("bert", "BertAdapterModel"), ("distilbert", "DistilBertAdapterModel"), ("deberta-v2", "DebertaV2AdapterModel"), diff --git a/src/transformers/adapters/models/beit/__init__.py b/src/transformers/adapters/models/beit/__init__.py new file mode 100644 index 0000000000..0472ad79d1 --- /dev/null +++ b/src/transformers/adapters/models/beit/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ....utils import _LazyModule + + +_import_structure = { + "adapter_model": ["BeitAdapterModel"], +} + + +if TYPE_CHECKING: + from .adapter_model import BeitAdapterModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/transformers/adapters/models/beit/adapter_model.py b/src/transformers/adapters/models/beit/adapter_model.py new file mode 100644 index 0000000000..c3c1260f78 --- /dev/null +++ b/src/transformers/adapters/models/beit/adapter_model.py @@ -0,0 +1,111 @@ +from typing import Optional + +import torch + +from ....models.beit.modeling_beit import BEIT_INPUTS_DOCSTRING, BEIT_START_DOCSTRING, BeitModel, BeitPreTrainedModel +from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...context import AdapterSetup +from ...heads import ImageClassificationHead, ModelWithFlexibleHeadsAdaptersMixin + + +@add_start_docstrings( + """Beit Model transformer with the option to add multiple flexible heads on top.""", + BEIT_START_DOCSTRING, +) +class BeitAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BeitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.beit = BeitModel(config) + + self._init_head_modules() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.beit( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + ) + + # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads + if not return_dict: + head_inputs = (outputs[0],) + outputs[2:] + else: + head_inputs = outputs + pooled_output = outputs[1] + + if head or AdapterSetup.get_context_head_setup() or self.active_head: + head_outputs = self.forward_head( + head_inputs, + cls_output=pooled_output, # BEiT does classification based on average-pooling of last hidden state + head_name=head, + return_dict=return_dict, + pooled_output=pooled_output, + **kwargs, + ) + return head_outputs + else: + # in case no head is used just return the output of the base model (including pooler output) + return outputs + + head_types = { + "image_classification": ImageClassificationHead, + } + + def add_image_classification_head( + self, + head_name, + num_labels=2, + layers=1, + activation_function="tanh", + overwrite_ok=False, + multilabel=False, + id2label=None, + use_pooler=True, + ): + """ + Adds an image classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 1. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. + """ + + head = ImageClassificationHead( + self, + head_name, + num_labels=num_labels, + layers=layers, + activation_function=activation_function, + multilabel=multilabel, + id2label=id2label, + use_pooler=use_pooler, + ) + self.add_prediction_head(head, overwrite_ok) diff --git a/src/transformers/adapters/wrappers/configuration.py b/src/transformers/adapters/wrappers/configuration.py index 3ffb3f22de..d58ebff7a6 100644 --- a/src/transformers/adapters/wrappers/configuration.py +++ b/src/transformers/adapters/wrappers/configuration.py @@ -10,6 +10,7 @@ "hidden_dropout_prob": "dropout", "attention_probs_dropout_prob": "attention_dropout", }, + "beit": {}, "bert": {}, "distilbert": { "hidden_dropout_prob": "dropout", diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 9a1ca6c9de..de9f4591fa 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -26,6 +26,15 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...adapters.context import ForwardContext +from ...adapters.lora import Linear as LoRALinear +from ...adapters.mixins.beit import ( + BeitLayerAdaptersMixin, + BeitModelAdaptersMixin, + BeitModelWithHeadsAdaptersMixin, + BeitOutputAdaptersMixin, +) +from ...adapters.prefix_tuning import PrefixTuningShim from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -209,7 +218,9 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class BeitSelfAttention(nn.Module): - def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: + def __init__( + self, config: BeitConfig, window_size: Optional[tuple] = None, location_key: Optional[str] = None + ) -> None: super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -221,9 +232,9 @@ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> N self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False) - self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.query = LoRALinear(config.hidden_size, self.all_head_size, "selfattn", config, attn_key="q") + self.key = LoRALinear(config.hidden_size, self.all_head_size, "selfattn", config, attn_key="k", bias=False) + self.value = LoRALinear(config.hidden_size, self.all_head_size, "selfattn", config, attn_key="v") self.dropout = nn.Dropout(config.attention_probs_dropout_prob) @@ -232,6 +243,8 @@ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> N else: self.relative_position_bias = None + self.prefix_tuning = PrefixTuningShim(location_key + "_prefix" if location_key else None, config) + def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) @@ -250,6 +263,8 @@ def forward( value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states) + # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -299,14 +314,15 @@ def __init__(self, config: BeitConfig) -> None: def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - return hidden_states class BeitAttention(nn.Module): - def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: + def __init__( + self, config: BeitConfig, window_size: Optional[tuple] = None, location_key: Optional[str] = None + ) -> None: super().__init__() - self.attention = BeitSelfAttention(config, window_size=window_size) + self.attention = BeitSelfAttention(config, window_size=window_size, location_key=location_key) self.output = BeitSelfOutput(config) self.pruned_heads = set() @@ -346,7 +362,7 @@ def forward( class BeitIntermediate(nn.Module): def __init__(self, config: BeitConfig) -> None: super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.dense = LoRALinear(config.hidden_size, config.intermediate_size, "intermediate", config) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: @@ -355,31 +371,35 @@ def __init__(self, config: BeitConfig) -> None: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states -class BeitOutput(nn.Module): +class BeitOutput(BeitOutputAdaptersMixin, nn.Module): def __init__(self, config: BeitConfig) -> None: super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.config = config + + self.dense = LoRALinear(config.intermediate_size, config.hidden_size, "output", config) self.dropout = nn.Dropout(config.hidden_dropout_prob) + self._init_adapter_modules() - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - + hidden_states = self.output_adapters.adapter_layer_forward(hidden_states, input_tensor, None) return hidden_states -class BeitLayer(nn.Module): +class BeitLayer(BeitLayerAdaptersMixin, nn.Module): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None: super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = BeitAttention(config, window_size=window_size) + self.attention = BeitAttention(config, window_size=window_size, location_key="self") self.intermediate = BeitIntermediate(config) self.output = BeitOutput(config) self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -393,6 +413,8 @@ def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop else: self.lambda_1, self.lambda_2 = None, None + self._init_adapter_modules() + def forward( self, hidden_states: torch.Tensor, @@ -409,6 +431,8 @@ def forward( attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + hidden_states = self.attention_adapters.adapter_layer_forward(attention_output, hidden_states, None) + # apply lambda_1 if present if self.lambda_1 is not None: attention_output = self.lambda_1 * attention_output @@ -420,7 +444,7 @@ def forward( layer_output = self.layernorm_after(hidden_states) layer_output = self.intermediate(layer_output) - layer_output = self.output(layer_output) + layer_output = self.output(layer_output, hidden_states) if self.lambda_2 is not None: layer_output = self.lambda_2 * layer_output @@ -617,7 +641,7 @@ def _set_gradient_checkpointing(self, module, value=False): "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.", BEIT_START_DOCSTRING, ) -class BeitModel(BeitPreTrainedModel): +class BeitModel(BeitModelAdaptersMixin, BeitPreTrainedModel): def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None: super().__init__(config) self.config = config @@ -630,12 +654,17 @@ def __init__(self, config: BeitConfig, add_pooling_layer: bool = True) -> None: ) self.pooler = BeitPooler(config) if add_pooling_layer else None + self._init_adapter_modules() + # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.embeddings.patch_embeddings + def set_input_embeddings(self, value): + self.embeddings.patch_embeddings = value + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base @@ -653,6 +682,7 @@ class PreTrainedModel modality="vision", expected_output=_EXPECTED_OUTPUT_SHAPE, ) + @ForwardContext.wrap def forward( self, pixel_values: Optional[torch.Tensor] = None, @@ -729,7 +759,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: will need to use [`BeitForMaskedImageModeling`] directly if you wish to do masked image modeling with BEiT.""", BEIT_START_DOCSTRING, ) -class BeitForMaskedImageModeling(BeitPreTrainedModel): +class BeitForMaskedImageModeling(BeitModelWithHeadsAdaptersMixin, BeitPreTrainedModel): def __init__(self, config: BeitConfig) -> None: super().__init__(config) @@ -829,7 +859,7 @@ def forward( """, BEIT_START_DOCSTRING, ) -class BeitForImageClassification(BeitPreTrainedModel): +class BeitForImageClassification(BeitModelWithHeadsAdaptersMixin, BeitPreTrainedModel): def __init__(self, config: BeitConfig) -> None: super().__init__(config) diff --git a/tests_adapters/test_beit.py b/tests_adapters/test_beit.py new file mode 100644 index 0000000000..6a5f5fdbd7 --- /dev/null +++ b/tests_adapters/test_beit.py @@ -0,0 +1,69 @@ +import unittest + +from tests.models.beit.test_modeling_beit import * +from transformers import BeitAdapterModel +from transformers.testing_utils import require_torch + +from .methods import ( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, +) +from .test_adapter import VisionAdapterTestBase, make_config +from .test_adapter_backward_compability import CompabilityTestMixin +from .test_adapter_composition import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_embeddings import EmbeddingTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin +from .test_common import AdapterModelTesterMixin + + +@require_torch +class BeitAdapterModelTest(AdapterModelTesterMixin, BeitModelTest): + all_model_classes = ( + BeitAdapterModel, + ) + fx_compatible = False + + +class BeitAdapterTestBase(VisionAdapterTestBase): + config_class = BeitConfig + config = make_config( + BeitConfig, + image_size=224, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ) + feature_extractor_name = 'microsoft/beit-base-patch16-224-pt22k' + + +@require_torch +class BeitAdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, + AdapterFusionModelTestMixin, + CompabilityTestMixin, + PredictionHeadModelTestMixin, + BeitAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class BeitClassConversionTest( + ModelClassConversionTestMixin, + BeitAdapterTestBase, + unittest.TestCase, +): + pass diff --git a/utils/check_adapters.py b/utils/check_adapters.py index 3c34fb39cc..e371189c6e 100644 --- a/utils/check_adapters.py +++ b/utils/check_adapters.py @@ -5,6 +5,7 @@ MODELS_WITH_ADAPTERS = [ "bert", + "beit", "roberta", "xlm_roberta", "distilbert",