1 parent 2b6771f, commit c6a06df
Showing 14 changed files with 352 additions and 19 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
Bidirectional Encoder representation from Image Transformers (BEiT) | ||
=================================================================== | ||
|
||
The Bidirectional Encoder representation from Image Transformers (BEiT) model was proposed in `BEiT: BERT Pre-Training | ||
of Image Transformers <https://arxiv.org/abs/2106.08254>`__ by Hangbo Bao, Li Dong, Songhao Piao and Furu Wei. | ||
|
||
|
||
The abstract from the paper is the following: | ||
|
||
*We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation | ||
from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image | ||
modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e, image | ||
patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into | ||
visual tokens. Then we randomly mask some image patches and fed them into the backbone Transformer. The pre-training | ||
objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we | ||
directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. | ||
Experimental results on image classification and semantic segmentation show that our model achieves competitive results | ||
with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, | ||
significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains | ||
86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%).* | ||
|
||
BeitAdapterModel | ||
~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. autoclass:: transformers.adapters.BeitAdapterModel | ||
    :members: | ||
    :inherited-members: BeitPreTrainedModel |
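
The masked image modeling setup described in the abstract above can be illustrated with a small, library-free sketch. It only shows how an image is split into 16x16 patches and how a boolean mask over those patches might be drawn. The 224-pixel image size, the 40% mask ratio, and the helper names `patch_grid` and `random_patch_mask` are assumptions made for illustration; they are not taken from this commit, and uniform random sampling is a simplification of whatever masking strategy the paper actually uses.

```python
import random
from typing import List


def patch_grid(image_size: int, patch_size: int) -> int:
    """Return the number of patches along one side of a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return image_size // patch_size


def random_patch_mask(num_patches: int, mask_ratio: float, seed: int = 0) -> List[bool]:
    """Return a boolean list with True at the positions of masked patches."""
    rng = random.Random(seed)
    num_masked = int(num_patches * mask_ratio)
    masked = set(rng.sample(range(num_patches), num_masked))
    return [i in masked for i in range(num_patches)]


# Assumed values: a 224x224 image and the 16x16 patches mentioned in the abstract.
side = patch_grid(224, 16)
num_patches = side * side

# Assumed mask ratio of 0.4, chosen only for this illustration.
mask = random_patch_mask(num_patches, mask_ratio=0.4)
print(f"{sum(mask)} of {num_patches} patches are masked")
```

A mask of this kind, converted to a boolean tensor with a batch dimension, is presumably what the `bool_masked_pos` argument of `BeitAdapterModel.forward` is meant to carry.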
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import logging | ||
from typing import Iterable, Tuple | ||
|
||
import torch.nn as nn | ||
|
||
from ..layer import AdapterLayer | ||
from ..model_mixin import ModelAdaptersMixin, ModelWithHeadsAdaptersMixin | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class BeitLayerAdaptersMixin: | ||
    """Adds adapters to the BeitLayer module.""" | ||
|
||
    def _init_adapter_modules(self): | ||
        self.attention_adapters = AdapterLayer("mh_adapter", self.config) | ||
        self.attention_adapters._init_adapter_modules() | ||
|
||
|
||
class BeitOutputAdaptersMixin: | ||
    """Adds adapters to the BeitOutput module.""" | ||
|
||
    def _init_adapter_modules(self): | ||
        self.output_adapters = AdapterLayer("output_adapter", self.config) | ||
        self.output_adapters._init_adapter_modules() | ||
|
||
|
||
class BeitModelAdaptersMixin(ModelAdaptersMixin): | ||
    """Adds adapters to the BeitModel module.""" | ||
|
||
    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: | ||
        for i, layer in enumerate(self.encoder.layer): | ||
            yield i, layer | ||
|
||
|
||
class BeitModelWithHeadsAdaptersMixin(ModelWithHeadsAdaptersMixin): | ||
    pass |
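
The `iter_layers` generator above gives model-agnostic code a uniform way to walk the encoder layers as `(index, layer)` pairs. A minimal sketch of that pattern follows, using plain Python stand-ins instead of `torch.nn` modules. `StubLayer`, `StubEncoder` and `StubModel` are hypothetical names invented for this example; how `ModelAdaptersMixin` actually consumes `iter_layers` is not shown in this diff.

```python
from typing import Iterable, List, Tuple


class StubLayer:
    """Hypothetical stand-in for a transformer layer carrying an adapter mixin."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.adapters_initialized = False

    def _init_adapter_modules(self) -> None:
        # The real mixins create AdapterLayer objects here; the stub only sets a flag.
        self.adapters_initialized = True


class StubEncoder:
    """Hypothetical encoder exposing its layers as `layer`, like `self.encoder.layer`."""

    def __init__(self, num_layers: int) -> None:
        self.layer: List[StubLayer] = [StubLayer(f"layer_{i}") for i in range(num_layers)]


class StubModel:
    """Hypothetical model exposing the same iter_layers contract as the diff."""

    def __init__(self, num_layers: int) -> None:
        self.encoder = StubEncoder(num_layers)

    def iter_layers(self) -> Iterable[Tuple[int, StubLayer]]:
        for i, layer in enumerate(self.encoder.layer):
            yield i, layer


model = StubModel(num_layers=3)

# Generic code can initialize every layer without knowing the model's internals.
for _, layer in model.iter_layers():
    layer._init_adapter_modules()

indices = [i for i, _ in model.iter_layers()]
print(indices)
```

Because the generator is re-created on every call, callers can iterate over the layers as often as needed.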
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# flake8: noqa | ||
# There's no way to ignore "F401 '...' imported but unused" warnings in this | ||
# module, but to preserve other warnings. So, don't check this module at all. | ||
|
||
# Copyright 2020 The Adapter-Hub Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from ....utils import _LazyModule | ||
|
||
|
||
_import_structure = { | ||
    "adapter_model": ["BeitAdapterModel"], | ||
} | ||
|
||
|
||
if TYPE_CHECKING: | ||
    from .adapter_model import BeitAdapterModel | ||
|
||
else: | ||
    import sys | ||
|
||
    sys.modules[__name__] = _LazyModule( | ||
        __name__, | ||
        globals()["__file__"], | ||
        _import_structure, | ||
    ) |
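
The `__init__` module above replaces itself in `sys.modules` with a lazy module, so `adapter_model` is only imported when `BeitAdapterModel` is first accessed. The sketch below illustrates the general idea with the standard library only. The `LazyModule` class here is a simplified, hypothetical stand-in and not the actual `_LazyModule` implementation; for the demo it maps attribute names to absolute standard-library modules rather than to relative submodules.

```python
import importlib
import types
from typing import Any, Dict, List


class LazyModule(types.ModuleType):
    """Simplified lazy module: submodules are imported on first attribute access."""

    def __init__(self, name: str, import_structure: Dict[str, List[str]]) -> None:
        super().__init__(name)
        # Map each exported attribute to the module that defines it.
        self._attr_to_module = {
            attr: module_name
            for module_name, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr: str) -> Any:
        # Only called when normal attribute lookup fails, i.e. on first access.
        if attr not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        module = importlib.import_module(self._attr_to_module[attr])
        value = getattr(module, attr)
        setattr(self, attr, value)  # cache so later lookups bypass __getattr__
        return value


# Demo structure using standard-library modules as the lazily imported targets.
lazy = LazyModule("demo_lazy", {"math": ["sqrt"], "json": ["dumps"]})

cached_before = "sqrt" in vars(lazy)
result = lazy.sqrt(9.0)
cached_after = "sqrt" in vars(lazy)
print(cached_before, result, cached_after)
```

The `TYPE_CHECKING` branch in the original keeps static analyzers and IDEs aware of the real import even though it never runs at runtime.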
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from ....models.beit.modeling_beit import BEIT_INPUTS_DOCSTRING, BEIT_START_DOCSTRING, BeitModel, BeitPreTrainedModel | ||
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward | ||
from ...context import AdapterSetup | ||
from ...heads import ImageClassificationHead, ModelWithFlexibleHeadsAdaptersMixin | ||
|
||
|
||
@add_start_docstrings( | ||
    """Beit Model transformer with the option to add multiple flexible heads on top.""", | ||
    BEIT_START_DOCSTRING, | ||
) | ||
class BeitAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BeitPreTrainedModel): | ||
    def __init__(self, config): | ||
        super().__init__(config) | ||
|
||
        self.beit = BeitModel(config) | ||
|
||
        self._init_head_modules() | ||
|
||
        # Initialize weights and apply final processing | ||
        self.post_init() | ||
|
||
    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) | ||
    def forward( | ||
        self, | ||
        pixel_values: Optional[torch.Tensor] = None, | ||
        bool_masked_pos: Optional[torch.BoolTensor] = None, | ||
        head_mask: Optional[torch.Tensor] = None, | ||
        output_attentions: Optional[bool] = None, | ||
        output_hidden_states: Optional[bool] = None, | ||
        return_dict: Optional[bool] = None, | ||
        head=None, | ||
        output_adapter_gating_scores=False, | ||
        output_adapter_fusion_attentions=False, | ||
        **kwargs, | ||
    ): | ||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
        outputs = self.beit( | ||
            pixel_values, | ||
            bool_masked_pos=bool_masked_pos, | ||
            head_mask=head_mask, | ||
            output_attentions=output_attentions, | ||
            output_hidden_states=output_hidden_states, | ||
            return_dict=return_dict, | ||
            output_adapter_gating_scores=output_adapter_gating_scores, | ||
            output_adapter_fusion_attentions=output_adapter_fusion_attentions, | ||
        ) | ||
|
||
        # BeitModel returns the pooled output as the second item; heads receive it separately via pooled_output | ||
        if not return_dict: | ||
            head_inputs = (outputs[0],) + outputs[2:] | ||
        else: | ||
            head_inputs = outputs | ||
        pooled_output = outputs[1] | ||
|
||
        if head or AdapterSetup.get_context_head_setup() or self.active_head: | ||
            head_outputs = self.forward_head( | ||
                head_inputs, | ||
                cls_output=pooled_output,  # BEiT does classification based on average-pooling of last hidden state | ||
                head_name=head, | ||
                return_dict=return_dict, | ||
                pooled_output=pooled_output, | ||
                **kwargs, | ||
            ) | ||
            return head_outputs | ||
        else: | ||
            # in case no head is used just return the output of the base model (including pooler output) | ||
            return outputs | ||
|
||
    head_types = { | ||
        "image_classification": ImageClassificationHead, | ||
    } | ||
|
||
    def add_image_classification_head( | ||
        self, | ||
        head_name, | ||
        num_labels=2, | ||
        layers=1, | ||
        activation_function="tanh", | ||
        overwrite_ok=False, | ||
        multilabel=False, | ||
        id2label=None, | ||
        use_pooler=True, | ||
    ): | ||
        """ | ||
        Adds an image classification head on top of the model. | ||
        Args: | ||
            head_name (str): The name of the head. | ||
            num_labels (int, optional): Number of classification labels. Defaults to 2. | ||
            layers (int, optional): Number of layers. Defaults to 1. | ||
            activation_function (str, optional): Activation function. Defaults to 'tanh'. | ||
            overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. | ||
            multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. | ||
            id2label (dict, optional): Mapping from label ids to label names. Defaults to None. | ||
            use_pooler (bool, optional): Whether the head uses the pooled output as its input. Defaults to True. | ||
        """ | ||
|
||
        head = ImageClassificationHead( | ||
            self, | ||
            head_name, | ||
            num_labels=num_labels, | ||
            layers=layers, | ||
            activation_function=activation_function, | ||
            multilabel=multilabel, | ||
            id2label=id2label, | ||
            use_pooler=use_pooler, | ||
        ) | ||
        self.add_prediction_head(head, overwrite_ok) |
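
Two pieces of plain-Python logic in `forward` above are easy to misread: the tuple slicing that drops the pooled output from the head inputs, and the condition that decides whether a head is applied at all. The sketch below reproduces both with strings standing in for tensors. `build_head_inputs` and `should_apply_head` are hypothetical helper names introduced for this example and are not part of the commit.

```python
from typing import Optional, Tuple


def build_head_inputs(outputs: Tuple) -> Tuple:
    """Drop the second item (the pooled output) from a tuple-style model output."""
    return (outputs[0],) + outputs[2:]


def should_apply_head(
    head: Optional[str],
    context_head: Optional[str],
    active_head: Optional[str],
) -> bool:
    """Mirror the `head or context head setup or active_head` check in forward."""
    return bool(head or context_head or active_head)


# Strings stand in for the tensors a tuple-returning base model would produce.
outputs = ("last_hidden_state", "pooler_output", "hidden_states", "attentions")

head_inputs = build_head_inputs(outputs)
pooled_output = outputs[1]
print(head_inputs)
print(pooled_output)

# No explicit head and no context head, but an active head is set: the head runs.
print(should_apply_head(None, None, "my_head"))
# Nothing is set: forward returns the base model outputs unchanged.
print(should_apply_head(None, None, None))
```

In the tuple case the pooled output is removed from `head_inputs` and handed to `forward_head` separately as both `cls_output` and `pooled_output`, which is why it is read from `outputs[1]` in either branch.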