diff --git a/docs/source/model_doc/detr.rst b/docs/source/model_doc/detr.rst
new file mode 100644
index 00000000000000..d23af0df4a3578
--- /dev/null
+++ b/docs/source/model_doc/detr.rst
@@ -0,0 +1,105 @@
+..
+    Copyright 2020 The HuggingFace Team. All rights reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+    the License. You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations under the License.
+
+DETR
+-----------------------------------------------------------------------------------------------------------------------
+
+Overview
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The DETR model was proposed in `End-to-End Object Detection with Transformers
+<https://arxiv.org/abs/2005.12872>`__ by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov
+and Sergey Zagoruyko. DETR consists of a convolutional backbone followed by an encoder-decoder Transformer which can be trained
+end-to-end for object detection. It greatly simplifies the detection pipeline compared to models like Faster R-CNN and Mask R-CNN,
+which rely on hand-designed components such as region proposals, non-maximum suppression and anchor generation. Moreover, DETR can
+be naturally extended to perform panoptic segmentation, by simply adding a mask head on top of the decoder outputs.
+
+The abstract from the paper is the following:
+
+*We present a new method that views object detection as a direct set prediction problem. Our approach streamlines the detection pipeline,
+effectively removing the need for many hand-designed components like a non-maximum suppression procedure or anchor generation that explicitly
+encode our prior knowledge about the task. The main ingredients of the new framework, called DEtection TRansformer or DETR, are a set-based
+global loss that forces unique predictions via bipartite matching, and a transformer encoder-decoder architecture. Given a fixed small set of
+learned object queries, DETR reasons about the relations of the objects and the global image context to directly output the final set of predictions
+in parallel. The new model is conceptually simple and does not require a specialized library, unlike many other modern detectors. DETR demonstrates
+accuracy and run-time performance on par with the well-established and highly-optimized Faster RCNN baseline on the challenging COCO object detection
+dataset. Moreover, DETR can be easily generalized to produce panoptic segmentation in a unified manner. We show that it significantly outperforms
+competitive baselines.*
+
+The original code can be found `here <https://github.com/facebookresearch/detr>`__.
+
+Here's a TLDR explaining how the model works:
+
+First, an image is sent through a pre-trained convolutional backbone (in the paper, the authors use ResNet-50/ResNet-101). Let's assume we also add a
+batch dimension. This means that the input to the backbone is a tensor of shape :obj:`(1, 3, height, width)`, assuming the image has 3 color channels (RGB).
+The CNN backbone outputs a new lower-resolution feature map, typically of shape :obj:`(1, 2048, height/32, width/32)`. This is then projected to match
+the hidden dimension of DETR's Transformer, which is :obj:`256` by default, using a :obj:`nn.Conv2d` layer. So now we have a tensor of shape
+:obj:`(1, 256, height/32, width/32)`. Next, the feature map is flattened and transposed to obtain a tensor of shape :obj:`(batch_size, seq_len, d_model)` =
+:obj:`(1, width/32*height/32, 256)`. A difference with NLP models is thus that the sequence length is longer than usual, but with a smaller
+:obj:`d_model` (which in NLP is typically 768 or higher).
+
+Next, this is sent through the encoder, outputting :obj:`encoder_hidden_states` of the same shape (you can consider these as image features). Next, so-called
+**object queries** are sent through the decoder. This is a tensor of shape :obj:`(batch_size, num_queries, d_model)`, with :obj:`num_queries` typically set
+to 100 and initialized with zeros. Each object query looks for a particular object in the image. Next, the decoder updates these object queries through
+multiple self-attention and encoder-decoder attention layers to output :obj:`decoder_hidden_states` of the same shape: :obj:`(batch_size, num_queries, d_model)`.
+Next, two heads are added on top for object detection: a linear layer to classify each object query into one of the objects or "no object", and an MLP
+to predict bounding boxes for each query.
+
+The model is trained using a **bipartite matching loss**: the predicted classes and bounding boxes of each of the N = 100 object queries are compared
+to the ground truth annotations, padded up to the same length N (so if an image only contains 4 objects, the remaining 96 annotations will simply have
+"no object" as class and "no bounding box" as bounding box). The `Hungarian matching algorithm <https://en.wikipedia.org/wiki/Hungarian_algorithm>`__
+is used to create a one-to-one mapping of each of the N queries to each of the N annotations. Next, standard cross-entropy and L1 bounding box losses
+are used to optimize the parameters of the model.
+
+Tips:
+
+- DETR uses so-called **object queries** to detect objects in an image. The number of queries determines the maximum number of objects that
+  can be detected in a single image, and is set to 100 by default (see parameter :obj:`num_queries` of :class:`~transformers.DetrConfig`).
+  Note that it's good to have some slack (in COCO, the authors used 100 queries, while the maximum number of objects in a COCO image is ~70).
+- The decoder of DETR updates the query embeddings in parallel. This is different from language models like GPT-2, which use autoregressive decoding
+  instead of parallel decoding. Hence, no causal attention mask is used.
+- DETR adds position embeddings to the hidden states at each self-attention and cross-attention layer before projecting them to queries and keys.
+  For the position embeddings of the image, one can choose between fixed sinusoidal or learned absolute position embeddings. By default,
+  the parameter :obj:`position_embedding_type` of :class:`~transformers.DetrConfig` is set to :obj:`"sine"`.
+- During training, the authors of DETR found it helpful to use auxiliary losses in the decoder, especially to help the model output the correct
+  number of objects of each class. If you set the parameter :obj:`auxiliary_loss` of :class:`~transformers.DetrConfig` to :obj:`True`, then prediction
+  feedforward neural networks and Hungarian losses are added after each decoder layer (with the FFNs sharing parameters).
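To make the shapes above concrete, here is a minimal sketch of the backbone + projection step and of the Hungarian matching step, using plain PyTorch/torchvision and SciPy. This is not the actual :class:`~transformers.DetrModel` implementation; the image size, the toy cost matrix and all variable names are illustrative assumptions::

    import numpy as np
    import torch
    import torchvision
    from scipy.optimize import linear_sum_assignment
    from torch import nn

    # Toy backbone + projection, mirroring the shapes described in the TLDR above
    backbone = torchvision.models.resnet50(pretrained=False)   # use pretrained weights in practice
    backbone = nn.Sequential(*list(backbone.children())[:-2])  # keep everything up to the last conv stage
    projection = nn.Conv2d(2048, 256, kernel_size=1)            # project to d_model = 256

    pixel_values = torch.randn(1, 3, 800, 800)                  # (batch_size, num_channels, height, width)
    feature_map = backbone(pixel_values)                        # (1, 2048, 25, 25) = (1, 2048, height/32, width/32)
    projected = projection(feature_map)                         # (1, 256, 25, 25)
    encoder_inputs = projected.flatten(2).transpose(1, 2)       # (1, 625, 256) = (batch_size, seq_len, d_model)

    object_queries = torch.zeros(1, 100, 256)                   # (batch_size, num_queries, d_model), decoder input

    # The bipartite matching relies on scipy's Hungarian algorithm: given a
    # (num_queries, num_targets) cost matrix, it returns a one-to-one assignment
    cost_matrix = np.array([[0.9, 0.1, 0.5],
                            [0.4, 0.8, 0.2],
                            [0.3, 0.6, 0.7]])
    query_indices, target_indices = linear_sum_assignment(cost_matrix)  # queries 0, 1, 2 -> targets 1, 2, 0

In DETR, the cost matrix itself combines a classification cost, an L1 bounding box cost and a generalized IoU cost, weighted by the :obj:`class_cost`, :obj:`bbox_cost` and :obj:`giou_cost` parameters of :class:`~transformers.DetrConfig`.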
+ +DetrConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrConfig + :members: + + +DetrTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrTokenizer + :members: __call__ + + +DetrModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrModel + :members: forward + + +DetrForObjectDetection +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrForObjectDetection + :members: forward + + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 41b18ad6d41394..8f24ebc61d289e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -134,6 +134,7 @@ "Wav2Vec2FeatureExtractor", "Wav2Vec2Processor", ], + "models.detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrTokenizer"], "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ @@ -288,6 +289,7 @@ # tokenziers-backed objects if is_tokenizers_available(): # Fast tokenizers + _import_structure["models.detr"].append("DetrTokenizerFast") _import_structure["models.convbert"].append("ConvBertTokenizerFast") _import_structure["models.albert"].append("AlbertTokenizerFast") _import_structure["models.bart"].append("BartTokenizerFast") @@ -376,6 +378,14 @@ _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] # PyTorch models structure + _import_structure["models.detr"].extend( + [ + "DETR_PRETRAINED_MODEL_ARCHIVE_LIST", + "DetrForObjectDetection", + "DetrModel", + ] + ) + _import_structure["models.wav2vec2"].extend( [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1297,6 +1307,7 @@ load_tf2_weights_in_pytorch_model, ) from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig + from .models.detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig, DetrTokenizer from .models.auto import ( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, @@ -1452,6 +1463,7 @@ from .utils.dummy_sentencepiece_objects import * if is_tokenizers_available(): + from .models.detr import DetrTokenizerFast from .models.albert import AlbertTokenizerFast from .models.bart import BartTokenizerFast from .models.barthez import BarthezTokenizerFast @@ -1491,6 +1503,12 @@ # Modeling if is_torch_available(): + from .models.detr import ( + DETR_PRETRAINED_MODEL_ARCHIVE_LIST, + DetrForObjectDetection, + DetrModel, + ) + # Benchmarks from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f7f9a9e58ded44..3189b1e8156875 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -17,6 +17,7 @@ # limitations under the License. from . 
import ( + detr, albert, auto, bart, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 338f273757573b..23606e1a934869 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -19,6 +19,7 @@ from ...configuration_utils import PretrainedConfig from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig +from ..detr.configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from ..bert_generation.configuration_bert_generation import BertGenerationConfig @@ -75,6 +76,7 @@ (key, value) for pretrained_map in [ # Add archive maps here + DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LED_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -120,6 +122,7 @@ CONFIG_MAPPING = OrderedDict( [ # Add configs here + ("detr", DetrConfig), ("wav2vec2", Wav2Vec2Config), ("convbert", ConvBertConfig), ("led", LEDConfig), @@ -171,6 +174,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("detr", "Detr"), ("wav2vec2", "Wav2Vec2"), ("convbert", "ConvBERT"), ("led", "LED"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 39e8b70b3ce1ed..67ffcd12758e20 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -23,6 +23,10 @@ from ...utils import logging # Add modeling imports here +from ..detr.modeling_detr import ( + DetrForObjectDetection, + DetrModel, +) from ..albert.modeling_albert import ( AlbertForMaskedLM, AlbertForMultipleChoice, @@ -68,6 +72,10 @@ ) # Add modeling imports here +from ..detr.modeling_detr import ( + DetrForObjectDetection, + DetrModel, +) from ..convbert.modeling_convbert import ( ConvBertForMaskedLM, ConvBertForMultipleChoice, @@ -258,6 +266,7 @@ XLNetModel, ) from .configuration_auto import ( + DetrConfig, AlbertConfig, AutoConfig, BartConfig, @@ -313,6 +322,7 @@ MODEL_MAPPING = OrderedDict( [ # Base model mapping + (DetrConfig, DetrModel), (Wav2Vec2Config, Wav2Vec2Model), (ConvBertConfig, ConvBertModel), (LEDConfig, LEDModel), @@ -396,6 +406,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ # Model with LM heads mapping + (Wav2Vec2Config, Wav2Vec2ForMaskedLM), (ConvBertConfig, ConvBertForMaskedLM), (LEDConfig, LEDForConditionalGeneration), @@ -495,6 +506,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + (LEDConfig, LEDForConditionalGeneration), (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), (MT5Config, MT5ForConditionalGeneration), diff --git a/src/transformers/models/detr/__init__.py b/src/transformers/models/detr/__init__.py new file mode 100644 index 00000000000000..182650212c9361 --- /dev/null +++ b/src/transformers/models/detr/__init__.py @@ -0,0 +1,71 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING +from ...file_utils import _BaseLazyModule, is_torch_available, is_tokenizers_available +_import_structure = { + "configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"], + "tokenization_detr": ["DetrTokenizer"], +} + +if is_tokenizers_available(): + _import_structure["tokenization_detr_fast"] = ["DetrTokenizerFast"] + +if is_torch_available(): + _import_structure["modeling_detr"] = [ + "DETR_PRETRAINED_MODEL_ARCHIVE_LIST", + "DetrForObjectDetection", + "DetrModel", + "DetrPreTrainedModel", + ] + + + + +if TYPE_CHECKING: + from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig + from .tokenization_detr import DetrTokenizer + + if is_tokenizers_available(): + from .tokenization_detr_fast import DetrTokenizerFast + + if is_torch_available(): + from .modeling_detr import ( + DETR_PRETRAINED_MODEL_ARCHIVE_LIST, + DetrForObjectDetection, + DetrModel, + DetrPreTrainedModel, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py new file mode 100644 index 00000000000000..be50649ca601e4 --- /dev/null +++ b/src/transformers/models/detr/configuration_detr.py @@ -0,0 +1,225 @@ +# coding=utf-8 +# Copyright Facebook AI Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DETR model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +DETR_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/detr-resnet-50": "https://huggingface.co/facebook/detr-resnet-50/resolve/main/config.json", + # See all DETR models at https://huggingface.co/models?filter=detr +} + + +class DetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.DetrModel`. + It is used to instantiate a DETR model according to the specified arguments, defining the model + architecture. 
Instantiating a configuration with the defaults will yield a similar configuration to that of + the DETR `facebook/detr-resnet-50 `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` + for more information. + + + Args: + num_queries (:obj:`int`, `optional`, defaults to 100): + Number of object queries, i.e. detection slots. This is the maximal number of objects + :class:`~transformers.DetrModel` can detect in a single image. For COCO, we recommend 100 queries. + d_model (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the layers. + encoder_layers (:obj:`int`, `optional`, defaults to 6): + Number of encoder layers. + decoder_layers (:obj:`int`, `optional`, defaults to 6): + Number of decoder layers. + encoder_attention_heads (:obj:`int`, `optional`, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (:obj:`int`, `optional`, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for classifier. + max_position_embeddings (:obj:`int`, `optional`, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the encoder. See the `LayerDrop paper `__ for more details. + decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). + auxiliary_loss (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`sine`): + Type of position embeddings to be used on top of the image features. One of 'sine' or 'learned'. + backbone (:obj:`bool`, `optional`, defaults to :obj:`resnet50`): + Name of convolutional backbone to use. 
Currently only resnet of the Torchvision package is supported. + train_backbone (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to train (fine-tune) the backbone. + dilation (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to replace stride with dilation in the last convolutional block (DC5). + masks (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to train the segmentation head. + class_cost (:obj:`float`, `optional`, defaults to 1): + Relative weight of the classification error in the Hungarian matching cost. + bbox_cost (:obj:`float`, `optional`, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (:obj:`float`, `optional`, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + mask_loss_coefficient (:obj:`float`, `optional`, defaults to 1): + Relative weight of the Focal loss in the panoptic segmentation loss. + dice_loss_coefficient (:obj:`float`, `optional`, defaults to 1): + Relative weight of the DICE/F-1 loss in the panoptic segmentation loss. + bbox_loss_coefficient (:obj:`float`, `optional`, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (:obj:`float`, `optional`, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (:obj:`float`, `optional`, defaults to 0.1): + Relative classification weight of the 'no-object' class in the object detection loss. + + Examples:: + + >>> from transformers import DetrModel, DetrConfig + + >>> # Initializing a DETR facebook/detr-resnet-50 style configuration + >>> configuration = DetrConfig() + + >>> # Initializing a model from the facebook/detr-resnet-50 style configuration + >>> model = DetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "detr" + keys_to_ignore_at_inference = ["past_key_values"] + def __init__( + self, + num_queries=100, + max_position_embeddings=1024, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + classifier_dropout=0.0, + scale_embedding=False, + gradient_checkpointing=False, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + auxiliary_loss=False, + position_embedding_type='sine', + backbone='resnet50', + train_backbone=True, + dilation=False, + masks=False, + class_cost=1, + bbox_cost=5, + giou_cost=2, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.1, + **kwargs + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs + ) + + self.num_queries = num_queries + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + 
self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.gradient_checkpointing = gradient_checkpointing + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.auxiliary_loss = auxiliary_loss + self.position_embedding_type = position_embedding_type + self.backbone = backbone + self.train_backbone = train_backbone + self.dilation = dilation + self.masks = masks + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.mask_loss_coefficient = mask_loss_coefficient + self.dice_loss_coefficient = dice_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.eos_coefficient = eos_coefficient + + + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model \ No newline at end of file diff --git a/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 00000000000000..18fea997ae15bb --- /dev/null +++ b/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DETR checkpoints.""" + + +import argparse +from pathlib import Path + +import torch +import torchvision.transforms as T +from packaging import version + +from PIL import Image +import requests + +from transformers import ( + DetrConfig, + DetrModel, + DetrForObjectDetection, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# here we list all keys to be renamed (original name on the left, our name on the right) +rename_keys = [] +for i in range(6): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append(("transformer.encoder.layers." + str(i) + ".self_attn.out_proj.weight", "encoder.layers." + str(i) + ".self_attn.out_proj.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".self_attn.out_proj.bias", "encoder.layers." + str(i) + ".self_attn.out_proj.bias")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".linear1.weight", "encoder.layers." + str(i) + ".fc1.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".linear1.bias", "encoder.layers." 
+ str(i) + ".fc1.bias")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".linear2.weight", "encoder.layers." + str(i) + ".fc2.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".linear2.bias", "encoder.layers." + str(i) + ".fc2.bias")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".norm1.weight", "encoder.layers." + str(i) + ".self_attn_layer_norm.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".norm1.bias", "encoder.layers." + str(i) + ".self_attn_layer_norm.bias")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".norm2.weight", "encoder.layers." + str(i) + ".final_layer_norm.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".norm2.bias", "encoder.layers." + str(i) + ".final_layer_norm.bias")) + # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + rename_keys.append(("transformer.decoder.layers." + str(i) + ".self_attn.out_proj.weight", "decoder.layers." + str(i) + ".self_attn.out_proj.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".self_attn.out_proj.bias","decoder.layers." + str(i) + ".self_attn.out_proj.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".multihead_attn.out_proj.weight", "decoder.layers." + str(i) + ".encoder_attn.out_proj.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".multihead_attn.out_proj.bias", "decoder.layers." + str(i) + ".encoder_attn.out_proj.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".linear1.weight", "decoder.layers." + str(i) + ".fc1.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".linear1.bias", "decoder.layers." + str(i) + ".fc1.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".linear2.weight", "decoder.layers." + str(i) + ".fc2.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".linear2.bias", "decoder.layers." + str(i) + ".fc2.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm1.weight", "decoder.layers." + str(i) + ".self_attn_layer_norm.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm1.bias", "decoder.layers." + str(i) + ".self_attn_layer_norm.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm2.weight", "decoder.layers." + str(i) + ".encoder_attn_layer_norm.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm2.bias", "decoder.layers." + str(i) + ".encoder_attn_layer_norm.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm3.weight", "decoder.layers." + str(i) + ".final_layer_norm.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm3.bias", "decoder.layers." 
+ str(i) + ".final_layer_norm.bias")) + + +# convolutional projection + query embeddings + layernorm of decoder +rename_keys.extend([("input_proj.weight", "input_projection.weight"), +("input_proj.bias", "input_projection.bias"), +("query_embed.weight", "query_position_embeddings.weight"), +("transformer.decoder.norm.weight", "decoder.layernorm.weight"), +("transformer.decoder.norm.bias", "decoder.layernorm.bias")]) + + +def remove_object_detection_heads_(state_dict): + ignore_keys = [ + "class_embed.weight", + "class_embed.bias", + "bbox_embed.layers.0.weight", + "bbox_embed.layers.0.bias", + "bbox_embed.layers.1.weight", + "bbox_embed.layers.1.bias", + "bbox_embed.layers.2.weight", + "bbox_embed.layers.2.bias" + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(state_dict, old, new): + val = state_dict.pop(old) + state_dict[new] = val + + +def read_in_q_k_v(state_dict): + # first: transformer encoder + for i in range(6): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop("transformer.encoder.layers." + str(i) + ".self_attn.in_proj_weight") + in_proj_bias = state_dict.pop("transformer.encoder.layers." + str(i) + ".self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict["encoder.layers." + str(i) + ".self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict["encoder.layers." + str(i) + ".self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict["encoder.layers." + str(i) + ".self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict["encoder.layers." + str(i) + ".self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict["encoder.layers." + str(i) + ".self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict["encoder.layers." + str(i) + ".self_attn.v_proj.bias"] = in_proj_bias[-256:] + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(6): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop("transformer.decoder.layers." + str(i) + ".self_attn.in_proj_weight") + in_proj_bias = state_dict.pop("transformer.decoder.layers." + str(i) + ".self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict["decoder.layers." + str(i) + ".self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict["decoder.layers." + str(i) + ".self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict["decoder.layers." + str(i) + ".self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict["decoder.layers." + str(i) + ".self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict["decoder.layers." + str(i) + ".self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict["decoder.layers." + str(i) + ".self_attn.v_proj.bias"] = in_proj_bias[-256:] + # read in weights + bias of input projection layer of cross-attention + in_proj_weight_cross_attn = state_dict.pop("transformer.decoder.layers." + str(i) + ".multihead_attn.in_proj_weight") + in_proj_bias_cross_attn = state_dict.pop("transformer.decoder.layers." + str(i) + ".multihead_attn.in_proj_bias") + # next, add query, keys and values (in that order) of cross-attention to the state dict + state_dict["decoder.layers." + str(i) + ".encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :] + state_dict["decoder.layers." 
+ str(i) + ".encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256] + state_dict["decoder.layers." + str(i) + ".encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :] + state_dict["decoder.layers." + str(i) + ".encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512] + state_dict["decoder.layers." + str(i) + ".encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :] + state_dict["decoder.layers." + str(i) + ".encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:] + + +# since we renamed the classification heads of the object detection model, we need to rename the original keys: +rename_keys_object_detection_model = [ +("class_embed.weight", "class_labels_classifier.weight"), +("class_embed.bias", "class_labels_classifier.bias"), +("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"), +("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"), +("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"), +("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"), +("bbox_embed.layers.2.weight","bbox_predictor.layers.2.weight"), +("bbox_embed.layers.2.bias","bbox_predictor.layers.2.bias"), +] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + im = Image.open(requests.get(url, stream=True).raw) + + # standard PyTorch mean-std input image normalization + transform = T.Compose([ + T.Resize(800), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + # mean-std normalize the input image (batch-size: 1) + img = transform(im).unsqueeze(0) + + return img + + +# COCO classes +CLASSES = [ + 'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', + 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', + 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', + 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', + 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', + 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', + 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', + 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', + 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', + 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', + 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', + 'toothbrush' +] + + +@torch.no_grad() +def convert_detr_checkpoint(task, backbone, dilation, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our DETR structure. 
+ """ + + config = DetrConfig() + img = prepare_img() + + logger.info(f"Converting model for task {task}, with a {backbone} backbone, dilation set to {dilation}...") + + if task == "base_model": + # load model from torch hub + detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).eval() + state_dict = detr.state_dict() + # rename keys + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict) + # remove classification heads + remove_object_detection_heads_(state_dict) + # finally, create model and load state dict + model = DetrModel(config).eval() + model.load_state_dict(state_dict) + # verify our conversion on the image + outputs = model(img) + assert outputs.last_hidden_state.shape == (1, config.num_queries, config.d_model) + expected_slice = torch.tensor([[0.0616, -0.5146, -0.4032], + [-0.7629, -0.4934, -1.7153], + [-0.4768, -0.6403, -0.7826]]) + assert torch.allclose(outputs.last_hidden_state[0,:3,:3], expected_slice, atol=1e-4) + + elif task == "object_detection": + # coco has 91 labels + config.num_labels = 91 + config.id2label = {v: k for v, k in enumerate(CLASSES)} + config.label2id = {k: v for v, k in enumerate(CLASSES)} + # load model from torch hub + if backbone == 'resnet_50' and not dilation: + detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).eval() + elif backbone == 'resnet_50' and dilation: + detr = torch.hub.load('facebookresearch/detr', 'detr_dc5_resnet50', pretrained=True).eval() + config.dilation = True + elif backbone == 'resnet_101' and not dilation: + detr = torch.hub.load('facebookresearch/detr', 'detr_resnet101', pretrained=True).eval() + config.backbone = 'resnet_101' + elif backbone == 'resnet_101' and dilation: + detr = torch.hub.load('facebookresearch/detr', 'detr_dc5_resnet101', pretrained=True).eval() + config.backbone = 'resnet_101' + config.dilation = True + else: + raise ValueError(f"Not supported: {backbone} with {dilation}") + + state_dict = detr.state_dict() + # rename keys + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict) + # rename classification heads + for src, dest in rename_keys_object_detection_model: + rename_key(state_dict, src, dest) + # important: we need to prepend "model." to each of the base model keys as DetrForObjectDetection calls the base model like this + for key in state_dict.copy().keys(): + if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"): + val = state_dict.pop(key) + state_dict["model." 
+ key] = val + # finally, create model and load state dict + model = DetrForObjectDetection(config).eval() + model.load_state_dict(state_dict) + # verify our conversion + original_outputs = detr(img) + outputs = model(img) + assert torch.allclose(outputs.pred_logits, original_outputs['pred_logits'], atol=1e-4) + assert torch.allclose(outputs.pred_boxes, original_outputs['pred_boxes'], atol=1e-4) + + elif task == "panoptic_segmentation": + # First, load in original detr from torch hub + if backbone == 'resnet_50' and not dilation: + detr, postprocessor = torch.hub.load('facebookresearch/detr', 'detr_resnet50_panoptic', + pretrained=True, return_postprocessor=True, num_classes=250) + detr.eval() + elif backbone == 'resnet_50' and dilation: + detr, postprocessor = torch.hub.load('facebookresearch/detr', 'detr_dc5_resnet50_panoptic', + pretrained=True, return_postprocessor=True, num_classes=250) + detr.eval() + config.dilation = True + elif backbone == 'resnet_101' and not dilation: + detr, postprocessor = torch.hub.load('facebookresearch/detr', 'detr_resnet101_panoptic', + pretrained=True, return_postprocessor=True, num_classes=250) + detr.eval() + config.backbone = 'resnet_101' + else: + print("Not supported:", backbone, dilation) + + else: + print("Task not in list of supported tasks:", task) + + # Save model + logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--task", default='base_model', type=str, help="""Task for which to convert a checkpoint. One of 'base_model', + 'object_detection' or 'panoptic_segmentation'.""") + parser.add_argument("--backbone", default='resnet_50', type=str, help="Which backbone to use. One of 'resnet50', 'resnet101'.") + parser.add_argument("--dilation", default=False, action="store_true", help="Whether to apply dilated convolution.") + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_detr_checkpoint(args.task, args.backbone, args.dilation, args.pytorch_dump_folder_path) \ No newline at end of file diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py new file mode 100644 index 00000000000000..5fca5d9072a093 --- /dev/null +++ b/src/transformers/models/detr/modeling_detr.py @@ -0,0 +1,2000 @@ +# coding=utf-8 +# Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DETR model. 
""" + + +import math +import random +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, List, Union +from scipy.optimize import linear_sum_assignment + +import torch +import torch.nn.functional as F +import torchvision +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.boxes import box_area +from torch import nn +from torch.nn import CrossEntropyLoss +from torch import Tensor +import torch.distributed as dist + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithCrossAttentions, + # BaseModelOutputWithPastAndCrossAttentions, (Niels): don't think we need this one as DETR uses parallel decoding + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_detr import DetrConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DetrConfig" +_TOKENIZER_FOR_DOC = "DetrTokenizer" + + +DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/detr-resnet-50", + # See all DETR models at https://huggingface.co/models?filter=detr +] + + +@dataclass +class BaseModelOutputWithCrossAttentionsAndIntermediateHiddenStates(BaseModelOutputWithCrossAttentions): + """ + This class adds one attribute to BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder + activations, i.e. the output of each decoder layer, each of them gone through a layernorm. + Args: + intermediate_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(config.decoder_layers, batch_size, sequence_length, hidden_size)`): + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + +@dataclass +class Seq2SeqModelOutputWithIntermediateHiddenStates(Seq2SeqModelOutput): + """ + This class adds one attribute to Seq2SeqModelOutput, namely an optional stack of intermediate decoder + activations, i.e. the output of each decoder layer, each of them gone through a layernorm. + Args: + intermediate_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(config.decoder_layers, batch_size, sequence_length, hidden_size)`): + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + +@dataclass +class DetrObjectDetectionOutput(ModelOutput): + """ + Output type of :class:`~transformers.DetrForObjectDetection`. + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` are provided)): + Total loss as the sum of (...). + pred_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). See PostProcess for information on how to retrieve the + unnormalized bounding box. + auxiliary_outputs (:obj:`list[Dict]`, `optional`): + Optional, only returned when auxilary losses are activated (i.e. config.auxiliary_loss is set to True) and labels are provided. It is a + list of dictionnaries containing the two above keys (pred_logits and pred_boxes) for each decoder layer. 
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2 tensors + of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + pred_logits: torch.FloatTensor = None + pred_boxes: torch.FloatTensor = None + auxiliary_outputs: Optional[List[Dict]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +## BELOW: utilities copied from +# https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/util/misc.py + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + """ + Data type that handles different types of inputs (either list of images or list of sequences), + and computes the padded output (with masking). + """ + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('Not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
+@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +## BELOW: utilities copied from +# https://github.com/facebookresearch/detr/blob/master/backbone.py + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the IntermediateLayerGetter + output_dict = self.body(pixel_values) + # currently there's no support for intermediate layers of the backbone + feature_map = output_dict["0"] + assert pixel_mask is not None + # we downsample the pixel_mask to match the feature map + mask = 
F.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + return feature_map, mask + + # # this one should be removed in the future + # def forward(self, tensor_list: NestedTensor): + # xs = self.body(tensor_list.tensors) + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = tensor_list.mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=True, norm_layer=FrozenBatchNorm2d) + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, pixel_values, pixel_mask): + # first, send pixel_values and pixel_mask through backbone to obtain updated feature_map and pixel_mask + feature_map, pixel_mask = self[0](pixel_values, pixel_mask) + + # next, create position embeddings for the outputs of the outputs_dict + pos = self[1](feature_map, pixel_mask).to(feature_map.dtype) + + return feature_map, pixel_mask, pos + + # def forward(self, tensor_list: NestedTensor): + # xs = self[0](tensor_list) + # print("The backbone outputs a NestedTensor with a Tensor and a Mask.") + # print("Shape of backbone tensor:") + # print(xs["0"].tensors.shape) + # print("Shape of backbone mask:") + # print(xs["0"].mask.shape) + # out: List[NestedTensor] = [] + # pos = [] + # for name, x in xs.items(): + # out.append(x) + # # position encoding + # pos.append(self[1](x).to(x.tensors.dtype)) + + # print(len(pos)) + # print("Shape of position embeddings:") + # print(pos[0].shape) + + # return out, pos + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None +): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +class DetrSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.embedding_dim = embedding_dim + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, pixel_values, pixel_mask): + x = pixel_values + mask = pixel_mask + assert mask is not None + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class DetrLearnedPositionEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + def __init__(self, embedding_dim=256): + super().__init__() + self.row_embeddings = nn.Embedding(50, embedding_dim) + self.column_embeddings = nn.Embedding(50, embedding_dim) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embeddings.weight) + nn.init.uniform_(self.column_embeddings.weight) + + def forward(self, pixel_values, pixel_mask=None): + x = pixel_values + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.column_embeddings(i) + y_emb = self.row_embeddings(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(config): + N_steps = config.d_model // 2 + if config.position_embedding_type == 'sine': + # TODO find a better way of exposing other arguments + position_embedding = DetrSinePositionEmbedding(N_steps, normalize=True) + elif config.position_embedding_type == 'learned': + position_embedding = DetrLearnedPositionEmbedding(N_steps) + else: + raise ValueError(f"not supported {config.position_embedding_type}") + + return position_embedding + + +# class DetrLearnedPositionalEmbedding(nn.Embedding): +# """ +# This module learns positional embeddings up to a fixed maximum size. 
+# """ + +# def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): +# assert padding_idx is not None, "`padding_idx` should not be None, but of type int" +# super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) + +# def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): +# """`input_ids_shape` is expected to be [bsz x seqlen].""" +# bsz, seq_len = input_ids_shape[:2] +# positions = torch.arange( +# past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device +# ) +# return super().forward(positions) + + +class DetrAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + # added (Niels) + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + key_value_position_embeddings: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # Added (Niels): add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # Added (Niels): add key-value position embeddings to the key value states + if key_value_position_embeddings is not None: + key_value_states_original = key_value_states + key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states 
= self._shape(self.v_proj(key_value_states_original), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ assert attn_weights.size() == (
+ bsz * self.num_heads,
+ tgt_len,
+ src_len,
+ ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
+
+ if attention_mask is not None:
+ assert attention_mask.size() == (
+ bsz,
+ 1,
+ tgt_len,
+ src_len,
+ ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = F.softmax(attn_weights, dim=-1)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ assert attn_output.size() == (
+ bsz * self.num_heads,
+ tgt_len,
+ self.head_dim,
+ ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
+
+ attn_output = (
+ attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ .transpose(1, 2)
+ .reshape(bsz, tgt_len, embed_dim)
+ )
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+class DetrEncoderLayer(nn.Module):
+ def __init__(self, config: DetrConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = DetrAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor = None,
+ output_attentions: bool = False):
+ """
+ Args:
+ hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (:obj:`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ position_embeddings (:obj:`torch.FloatTensor`, `optional`): position embeddings, to be added to hidden_states.
+ output_attentions (:obj:`bool`, `optional`):
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
+ returned tensors for more detail.
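+
+ Example (an illustrative sketch, assuming the default :class:`~transformers.DetrConfig` values; the shapes used are hypothetical)::
+
+ config = DetrConfig()
+ layer = DetrEncoderLayer(config)
+ hidden_states = torch.randn(1, 850, config.d_model) # (batch, seq_len, embed_dim)
+ outputs = layer(hidden_states, attention_mask=None)
+ assert outputs[0].shape == hidden_states.shape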
+ """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, + output_attentions=output_attentions + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class DetrDecoderLayer(nn.Module): + def __init__(self, config: DetrConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = DetrAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = DetrAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + position_embeddings=query_position_embeddings, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + position_embeddings=query_position_embeddings, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + key_value_position_embeddings=position_embeddings, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class DetrClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class DetrPreTrainedModel(PreTrainedModel): + config_class = DetrConfig + base_model_prefix = "model" + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + 
module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +DETR_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.DetrConfig`): + Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +DETR_INPUTS_DOCSTRING = r""" + Args: + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. + + Pixel values can be obtained using :class:`~transformers.DetrTokenizer`. See + :meth:`transformers.DetrTokenizer.__call__` for details. + + pixel_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding pixel values. Mask values selected in ``[0, 1]``: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + `What are attention masks? <../glossary.html#attention-mask>`__ + + decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Not used by default. + + If you want to change padding behavior, you should read :func:`modeling_detr._prepare_decoder_inputs` and + modify to your needs. See diagram 1 in `the paper `__ for more + information on the default strategy. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: + :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). 
This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +class DetrEncoder(DetrPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + :class:`DetrEncoderLayer`. + + The encoder updates the flattened feature map through multiple self-attention layers. + + Small tweaks for DETR: + - position_embeddings are added to the forward pass. + + Args: + config: DetrConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: DetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + # if embed_tokens is not None: + # self.embed_tokens = embed_tokens + # else: + # self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + # self.embed_positions = DetrLearnedPositionalEmbedding( + # config.max_position_embeddings, + # embed_dim, + # self.padding_idx, + # ) + + self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)]) + + # (Niels) in the original DETR, no layernorm is used for the Encoder, as "normalize_before" is set to False + + self.init_weights() + + def forward( + self, + #input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_embeddings=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding pixel features. Mask values selected in ``[0, 1]``: + + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). + + `What are attention masks? <../glossary.html#attention-mask>`__ + + position_embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. 
+ output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + # if input_ids is not None and inputs_embeds is not None: + # raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + # elif input_ids is not None: + # input_shape = input_ids.size() + # input_ids = input_ids.view(-1, input_shape[-1]) + # elif inputs_embeds is not None: + # input_shape = inputs_embeds.size()[:-1] + # else: + # raise ValueError("You have to specify either input_ids or inputs_embeds") + + # if inputs_embeds is None: + # inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed_pos = self.embed_positions(input_shape) + + # # add position embeddings + # hidden_states = inputs_embeds + embed_pos + # hidden_states = self.layernorm_embedding(hidden_states) + # hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = inputs_embeds + # (Niels) comment out layernorm, see __init__ above + #hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + ) + else: + # we add position_embeddings as extra input to the encoder_layer + layer_outputs = encoder_layer(hidden_states, + attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class DetrDecoder(DetrPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a :class:`DetrDecoderLayer`. + + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + + Some small tweaks for DETR: + - position_embeddings and query_position_embeddings are added to the forward pass. + - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. + + Args: + config: DetrConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: DetrConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + # don't think we need embed_tokens (output tokens) here, since we are just updating the query embeddings + + # if embed_tokens is not None: + # self.embed_tokens = embed_tokens + # else: + # self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + # self.embed_positions = DetrLearnedPositionalEmbedding( + # config.max_position_embeddings, + # config.d_model, + # self.padding_idx, + # ) + self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + # in DETR, the decoder uses layernorm after the last decoder layer output + self.layernorm = nn.LayerNorm(config.d_model) + + self.init_weights() + + def forward( + self, + #input_ids=None, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + position_embeddings=None, + query_position_embeddings=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + The query embeddings that are passed into the decoder. + + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on certain queries. Mask values selected in ``[0, 1]``: + + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values + selected in ``[0, 1]``: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + `What are attention masks? <../glossary.html#attention-mask>`__ + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. 
+ + If :obj:`past_key_values` are used, the user can optionally input only the last + :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of + shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, + sequence_length)`. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # (Niels) following lines are not required as DETR doesn't use input_ids and inputs_embeds + # retrieve input_ids and inputs_embeds + # if input_ids is not None and inputs_embeds is not None: + # raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + # elif input_ids is not None: + # input_shape = input_ids.size() + # input_ids = input_ids.view(-1, input_shape[-1]) + # elif inputs_embeds is not None: + # input_shape = inputs_embeds.size()[:-1] + # else: + # raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + # to do: should be updated, because no input_ids here + # if inputs_embeds is None: + # inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # added this (Niels) to infer input_shape: + if inputs_embeds is not None: + hidden_states = inputs_embeds + input_shape = inputs_embeds.size()[:-1] + + combined_attention_mask = None + # (Niels): following lines are not required as DETR uses parallel decoding instead of autoregressive + # # create causal mask + # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + # combined_attention_mask = None + + # if input_shape[-1] > 1: + # combined_attention_mask = _make_causal_mask( + # input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + # ).to(self.device) + + if attention_mask is not None and combined_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = combined_attention_mask + _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # (Niels): following lines are not required because adding position embeddings happens in DetrDecoderLayer + # embed positions + # positions = self.embed_positions(input_shape, past_key_values_length) + + # hidden_states = inputs_embeds + positions + # hidden_states = 
self.layernorm_embedding(inputs_embeds) + # hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # (Niels): added an optional list: + intermediate = [] if self.config.auxiliary_loss else None + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + next_decoder_cache = () if use_cache else None + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...") + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + combined_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + position_embeddings=position_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if self.config.auxiliary_loss: + intermediate.append(self.layernorm(hidden_states)) + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + all_cross_attentions += (layer_outputs[2],) + + # finally, apply layernorm + hidden_states = self.layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # stack intermediate decoder activations + if self.config.auxiliary_loss: + intermediate = torch.stack(intermediate) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions, + intermediate] + if v is not None + ) + return BaseModelOutputWithCrossAttentionsAndIntermediateHiddenStates( + last_hidden_state=hidden_states, + #past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + intermediate_hidden_states=intermediate, + ) + + +@add_start_docstrings( + """The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without + any specific head on top.""", + DETR_START_DOCSTRING, +) +class DetrModel(DetrPreTrainedModel): + def __init__(self, config: DetrConfig): + super().__init__(config) + + # Create backbone + positional encoding + backbone = Backbone(config.backbone, config.train_backbone, config.masks, config.dilation) + position_embeddings = 
build_position_encoding(config)
+ self.backbone = Joiner(backbone, position_embeddings)
+
+ # Create projection layer
+ self.input_projection = nn.Conv2d(backbone.num_channels, config.d_model, kernel_size=1)
+
+ self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model)
+
+ self.encoder = DetrEncoder(config)
+ self.decoder = DetrDecoder(config)
+
+ self.init_weights()
+
+ # def get_input_embeddings(self):
+ # return self.shared
+
+ # def set_input_embeddings(self, value):
+ # self.shared = value
+ # self.encoder.embed_tokens = self.shared
+ # self.decoder.embed_tokens = self.shared
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING)
+ # @add_code_sample_docstrings(
+ # tokenizer_class=_TOKENIZER_FOR_DOC,
+ # checkpoint="facebook/detr-resnet-50",
+ # output_type=Seq2SeqModelOutput,
+ # config_class=_CONFIG_FOR_DOC,
+ # )
+ def forward(
+ self,
+ #samples: NestedTensor=None,
+ pixel_values,
+ pixel_mask=None,
+ decoder_input_ids=None,
+ decoder_attention_mask=None,
+ encoder_outputs=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ decoder_inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size, num_channels, height, width = pixel_values.shape
+ device = pixel_values.device
+
+ if pixel_mask is None:
+ pixel_mask = torch.ones(((batch_size, height, width)), device=device)
+
+ # First, send pixel_values + pixel_mask through the backbone to obtain the features
+ # (includes feature map, downsampled mask and position embeddings)
+ # pixel_values should be of shape (batch_size, num_channels, height, width)
+ # pixel_mask should be of shape (batch_size, height, width)
+ feature_map, mask, position_embeddings = self.backbone(pixel_values, pixel_mask)
+
+ assert mask is not None
+
+ # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
+ projected_feature_map = self.input_projection(feature_map)
+
+ # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
+ # In other words, turn their shape into (batch_size, sequence_length, hidden_size)
+ flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
+ position_embeddings = position_embeddings.flatten(2).permute(0, 2, 1)
+
+ flattened_mask = mask.flatten(1)
+
+ # Fourth, send flattened_features + flattened_mask + position embeddings through the encoder
+ # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size)
+ # flattened_mask is a Tensor of shape (batch_size, height*width)
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ inputs_embeds=flattened_features,
+ attention_mask=flattened_mask,
+ position_embeddings=position_embeddings,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs,
BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output) + query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) + queries = torch.zeros_like(query_position_embeddings) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + inputs_embeds=queries, + attention_mask=None, + position_embeddings=position_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=flattened_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutputWithIntermediateHiddenStates( + last_hidden_state=decoder_outputs.last_hidden_state, + #past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, + ) + + +@add_start_docstrings( + """DETR Model (consisting of a backbone and encoder-decoder Transformer) with an object detection head on top, + for tasks such as COCO.""", + DETR_START_DOCSTRING, +) +class DetrForObjectDetection(DetrPreTrainedModel): + def __init__(self, config: DetrConfig): + super().__init__(config) + + # DETR base model + self.model = DetrModel(config) + + # Object detection heads + self.class_labels_classifier = nn.Linear(config.d_model, config.num_labels + 1) + self.bbox_predictor = MLP(input_dim=config.d_model, hidden_dim=config.d_model, + output_dim=4, num_layers=3) + + self.init_weights() + + # copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING) + # @add_code_sample_docstrings( + # tokenizer_class=_TOKENIZER_FOR_DOC, + # checkpoint="facebook/detr-resnet-50", + # output_type=DetrObjectDetectionOutput, + # config_class=_CONFIG_FOR_DOC, + # ) + def forward( + self, + #samples: NestedTensor=None, + pixel_values, + pixel_mask, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`List[Dict]` of len :obj:`(batch_size,)`, `optional`): + Labels for computing the bipartite matching loss. 
List of dicts, each dictionary containing 2 keys: 'class_labels' and
+ 'boxes' (the class labels and bounding boxes of an image in the batch respectively). The class labels themselves should
+ be a :obj:`torch.LongTensor` of len :obj:`(number of bounding boxes in the image,)` and the boxes a :obj:`torch.FloatTensor`
+ of shape :obj:`(number of bounding boxes in the image, 4)`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # First, send images through DETR base model to obtain encoder + decoder outputs
+ outputs = self.model(pixel_values, pixel_mask,
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ encoder_outputs=encoder_outputs,
+ #past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,)
+
+ # class logits + predicted bounding boxes
+ # to do: make this as efficient as the original implementation
+ # (index into outputs so this also works when return_dict=False and a plain tuple is returned)
+ pred_logits = self.class_labels_classifier(outputs[0])
+ pred_boxes = self.bbox_predictor(outputs[0]).sigmoid()
+
+ loss, auxiliary_outputs = None, None
+ if labels is not None:
+ # First: create the matcher
+ matcher = HungarianMatcher(class_cost=self.config.class_cost,
+ bbox_cost=self.config.bbox_cost,
+ giou_cost=self.config.giou_cost)
+ # Second: create the criterion
+ weight_dict = {'loss_ce': 1, 'loss_bbox': self.config.bbox_loss_coefficient}
+ weight_dict['loss_giou'] = self.config.giou_loss_coefficient
+ # to do: move the following three lines to DetrForPanopticSegmentation
+ if self.config.masks:
+ weight_dict["loss_mask"] = self.config.mask_loss_coef
+ weight_dict["loss_dice"] = self.config.dice_loss_coef
+ # TODO this is a hack (doesn't work yet)
+ if self.config.auxiliary_loss:
+ aux_weight_dict = {}
+ for i in range(self.config.decoder_layers - 1):
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
+ weight_dict.update(aux_weight_dict)
+
+ losses = ['labels', 'boxes', 'cardinality']
+ # to do: move the following two lines to DetrForPanopticSegmentation
+ if self.config.masks:
+ losses += ["masks"]
+ # (copied from original repo in detr.py):
+ # the naming of the `num_classes` parameter of the criterion is somewhat misleading.
+ # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
+ # is the maximum id for a class in your dataset. For example,
+ # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
+ # As another example, for a dataset that has a single class with id 1,
+ # you should pass `num_classes` to be 2 (max_obj_id + 1).
+ # For more details on this, check the following discussion
+ # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
+ criterion = SetCriterion(matcher=matcher, num_classes=self.config.num_labels,
+ weight_dict=weight_dict, eos_coef=self.config.eos_coefficient, losses=losses)
+ criterion.to(self.device)
+ # Third: compute the loss, based on outputs and labels
+ # (use a separate dict here, so we don't overwrite `outputs` of the base model)
+ outputs_loss = {}
+ outputs_loss['pred_logits'] = pred_logits
+ outputs_loss['pred_boxes'] = pred_boxes
+ if self.config.auxiliary_loss:
+ intermediate = outputs.intermediate_hidden_states if return_dict else outputs[6]
+ outputs_class = self.class_labels_classifier(intermediate)
+ outputs_coord = self.bbox_predictor(intermediate).sigmoid()
+ auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord)
+ outputs_loss['auxiliary_outputs'] = auxiliary_outputs
+
+ loss_dict = criterion(outputs_loss, labels)
+ weight_dict = criterion.weight_dict
+ loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
+
+ if not return_dict:
+ # to be verified
+ if auxiliary_outputs is not None:
+ output = (pred_logits, pred_boxes) + tuple(auxiliary_outputs) + outputs
+ else:
+ output = (pred_logits, pred_boxes) + outputs
+ return ((loss,) + output) if loss is not None else output
+
+ return DetrObjectDetectionOutput(
+ loss=loss,
+ pred_logits=pred_logits,
+ pred_boxes=pred_boxes,
+ auxiliary_outputs=auxiliary_outputs,
+ #past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+
+# copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+class SetCriterion(nn.Module):
+ """ This class computes the loss for DetrForObjectDetection.
+ The process happens in two steps:
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+ """
+ def __init__(self, matcher, num_classes, weight_dict, eos_coef, losses):
+ """ Create the criterion.
+ Parameters:
+ matcher: module able to compute a matching between targets and proposals.
+ num_classes: number of object categories, omitting the special no-object category.
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ eos_coef: relative classification weight applied to the no-object category.
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
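+
+ Example (an illustrative sketch; the cost/loss weights and the COCO-style num_classes below are hypothetical)::
+
+ matcher = HungarianMatcher(class_cost=1, bbox_cost=5, giou_cost=2)
+ weight_dict = {"loss_ce": 1, "loss_bbox": 5, "loss_giou": 2}
+ criterion = SetCriterion(matcher=matcher, num_classes=91, weight_dict=weight_dict,
+ eos_coef=0.1, losses=["labels", "boxes", "cardinality"])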
+ """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer('empty_weight', empty_weight) + + # (Niels): set log to False because we don't want to include accuracy in the modeling file + def loss_labels(self, outputs, targets, indices, num_boxes, log=False): + """Classification loss (NLL) + targets dicts must contain the key "class_labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(generalized_box_iou( + box_cxcywh_to_xyxy(src_boxes), + box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. 
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + src_masks = outputs["pred_masks"] + src_masks = src_masks[src_idx] + masks = [t["masks"] for t in targets] + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(src_masks) + target_masks = target_masks[tgt_idx] + + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(src_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'auxiliary_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["class_labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + # (Niels): comment out function below, distributed training to be added + # if is_dist_avail_and_initialized(): + # torch.distributed.all_reduce(num_boxes) + # (Niels) in original implementation, num_boxes is divided by get_world_size() + num_boxes = torch.clamp(num_boxes, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
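+ # (note) e.g. assuming the default 6 decoder layers, auxiliary_outputs holds 5 extra prediction sets, so the
+ # loop below adds suffixed keys such as 'loss_ce_0', 'loss_bbox_0', 'loss_giou_0', ..., 'loss_giou_4' to the loss dict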
+ if 'auxiliary_outputs' in outputs: + for i, auxiliary_outputs in enumerate(outputs['auxiliary_outputs']): + indices = self.matcher(auxiliary_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs = {'log': False} + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +# copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class MLP(nn.Module): + """ + Very simple multi-layer perceptron (also called FFN), used to predict the normalized + center coordinates, height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +# copied from https://github.com/facebookresearch/detr/blob/master/models/matcher.py +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1): + """Creates the matcher. + Params: + class_cost: This is the relative weight of the classification error in the matching cost + bbox_cost: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + giou_cost: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + assert class_cost != 0 or bbox_cost != 0 or giou_cost != 0, "All costs of the Matcher can't be 0" + + @torch.no_grad() + def forward(self, outputs, targets): + """ Performs the matching. 
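+
+        A cost matrix is built as a weighted sum of a classification cost (based on the predicted probability of the
+        target class), an L1 cost between predicted and target boxes, and a generalized IoU cost, after which scipy's
+        linear_sum_assignment solves the resulting assignment problem for each batch element.
+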
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        bs, num_queries = outputs['pred_logits'].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        out_prob = outputs['pred_logits'].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
+        out_bbox = outputs['pred_boxes'].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+        # Also concat the target labels and boxes
+        tgt_ids = torch.cat([v["class_labels"] for v in targets])
+        tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+        # but approximate it by 1 - proba[target class].
+        # The 1 is a constant that doesn't change the matching, so it can be omitted.
+        class_cost = -out_prob[:, tgt_ids]
+
+        # Compute the L1 cost between boxes
+        bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+        # Compute the giou cost between boxes
+        giou_cost = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+
+        # Final cost matrix
+        C = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
+        C = C.view(bs, num_queries, -1).cpu()
+
+        sizes = [len(v["boxes"]) for v in targets]
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+# below: functions copied from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
+
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2,
+         (x1 - x0), (y1 - y0)]
+    return torch.stack(b, dim=-1)
+
+
+# modified from torchvision to also return the union
+def box_iou(boxes1, boxes2):
+    area1 = box_area(boxes1)
+    area2 = box_area(boxes2)
+
+    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
+    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    union = area1[:, None] + area2 - inter
+
+    iou = inter / union
+    return iou, union
+
+
+def generalized_box_iou(boxes1, boxes2):
+    """
+    Generalized IoU from https://giou.stanford.edu/
+    The boxes should be in [x0, y0, x1, y1] format
+    Returns a [N, M] pairwise matrix, where N = len(boxes1)
+    and M = len(boxes2)
+    """
+    # degenerate boxes give inf / nan results
+    # so do an early check
+    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+    assert (boxes2[:, 2:] >= boxes2[:,
:2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area \ No newline at end of file diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py new file mode 100644 index 00000000000000..26b618ce9d560a --- /dev/null +++ b/src/transformers/models/detr/tokenization_detr.py @@ -0,0 +1,591 @@ +# coding=utf-8 +# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for DETR.""" + +from typing import Dict, List, Optional, Tuple, Union +from pathlib import Path +import json +from collections import defaultdict +import time + +import numpy as np +import torch +from torch import Tensor +import torchvision +import torchvision.transforms.functional as F +from torchvision import transforms as T +import PIL + +from ...file_utils import add_end_docstrings +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +DETR_KWARGS_DOCSTRING = r""" + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. 
+ * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to print more information and warnings. +""" + +## BELOW: utilities copied from +# https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/util/misc.py + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + """ + Data type that handles different types of inputs (either list of images or list of sequences), + and computes the padded output (with masking). + """ + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: Union[List[Tensor], torch.Tensor]): + # TODO make this more n + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.zeros((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = True + else: + raise ValueError('Not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
+# Note: inverting mask has not yet been taken into account +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +## Image + target transformations for object detection +## Taken from https://github.com/facebookresearch/detr/blob/master/datasets/transforms.py + +# this: extra transform, based on https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/datasets/coco.py#L21 +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False): + self.return_masks = return_masks + + def __call__(self, image, target): + w, h = image.size + + #image_id = target["image_id"] + #image_id = torch.tensor([image_id]) + + #anno = target["annotations"] + + #anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] + + #boxes = [obj["bbox"] for obj in anno] + if target is not None: + boxes = target["boxes"] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + #classes = [obj["category_id"] for obj in anno] + #classes = torch.tensor(classes, dtype=torch.int64) + + if self.return_masks: + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + # keypoints = None + # if anno and "keypoints" in anno[0]: + # keypoints = [obj["keypoints"] for obj in anno] + # keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + # num_keypoints = keypoints.shape[0] + # if num_keypoints: + # keypoints = keypoints.view(num_keypoints, -1, 3) + + # keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + # boxes = boxes[keep] + # classes = classes[keep] + # if self.return_masks: + # masks = masks[keep] + # if keypoints is not None: + # keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + # target["labels"] = classes + # if self.return_masks: + # target["masks"] = masks + # target["image_id"] = image_id + # if keypoints is not None: + # target["keypoints"] = keypoints + + # # for conversion to coco api + # area = torch.tensor([obj["area"] for obj in anno]) + # iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + # target["area"] = area[keep] + # target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + target["size"] = torch.as_tensor([int(h), int(w)]) + + return image, target + return image, None + + +def resize(image, 
target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + + return rescaled_image, target + + +class Resize(object): + def __init__(self, size, max_size=None): + self.size = size + self.max_size = max_size + + def __call__(self, image, target=None): + return resize(image, target, self.size, self.max_size) + +# copied from https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/util/box_ops.py +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target=None): + for t in self.transforms: + image, target = t(image, target) + return image, target + + +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target + + +class DetrTokenizer(PreTrainedTokenizer): + """ + Constructs a DETR tokenizer. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains some of the main methods. + Users should refer to the superclass for more information regarding such methods. + + Args: + bos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The beginning of sentence token. + eos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The end of sentence token. + unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The unknown token. 
A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (:obj:`str`, `optional`, defaults to :obj:`""`):
+            The token used for padding, for example when batching sequences of different lengths.
+        word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
+            The token used for defining the end of a word.
+        **kwargs
+            Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
+    """
+
+    model_input_names = ["input_values"]
+
+    def __init__(
+        self,
+        bos_token="",
+        eos_token="",
+        unk_token="",
+        pad_token="",
+        word_delimiter_token="|",
+        do_lower_case=False,
+        **kwargs
+    ):
+        super().__init__(
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            do_lower_case=do_lower_case,
+            word_delimiter_token=word_delimiter_token,
+            **kwargs,
+        )
+        self._word_delimiter_token = word_delimiter_token
+        self.do_lower_case = do_lower_case
+
+        # load dataset
+        self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
+        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
+
+    @add_end_docstrings(DETR_KWARGS_DOCSTRING)
+    def __call__(
+        self,
+        images: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor], str],
+        annotations: Optional[Union[Dict, List[Dict], str]] = None,
+        padding: Union[bool, str] = True,
+        return_mask: Union[bool, str] = True,
+        resize: Optional[bool] = True,
+        size: Optional[int] = 800,
+        max_size: Optional[int] = 1333,
+        normalize: Optional[bool] = True,
+        verbose: bool = True,
+        **kwargs
+    ) -> BatchEncoding:
+        """
+        Main method to prepare one or several image(s) for the model. This includes resizing, normalization and
+        padding up to the largest image in a batch, while creating a pixel mask for each image indicating which
+        pixels are real and which are padding.
+
+        Args:
+            images (:obj:`PIL.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`, :obj:`str`):
+                The image or batch of images to be prepared. Each image can be a PIL image, numpy array or a Torch tensor. You can also provide the path
+                to a directory containing the images.
+            annotations (:obj:`Dict`, :obj:`List[Dict]`):
+                The annotations as either a Python dictionary (in case of a single image) or a list of Python dictionaries (in case
+                of a batch of images). Each dictionary should include the following keys:
+
+                - boxes (:obj:`List[List[float]]`): the coordinates of the N bounding boxes in COCO format [x, y, width, height]: the x and y coordinates of the top-left corner, followed by the width and height of the box.
+                - labels (:obj:`List[int]`): the label for each bounding box. 0 represents the background (i.e. 'no object') class.
+                - (optionally) masks (:obj:`List[List[float]]`): the segmentation masks for each of the objects.
+                - (optionally) keypoints (FloatTensor[N, K, 3]): for each one of the N objects, it contains the K keypoints in [x, y, visibility] format, defining the object. visibility=0 means that the keypoint is not visible. Note that for data augmentation, the notion of flipping a keypoint is dependent on the data representation, and you should probably adapt references/detection/transforms.py for your new keypoint representation.
+
+                Instead, you can also provide a path to a json file in COCO format.
+
+            resize (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Whether to resize images (and annotations) to a certain :obj:`size`.
+            size (:obj:`int`, `optional`, defaults to :obj:`800`):
+                Resize the input image to the given size. Only has an effect if :obj:`resize` is set to :obj:`True`.
+            max_size (:obj:`int`, `optional`, defaults to :obj:`1333`):
+                The largest size an image dimension can have (otherwise it's capped). On COCO, the authors of DETR used a :obj:`max_size` of 1333.
+            normalize (:obj:`bool`, `optional`, defaults to :obj:`True`):
+                Whether to apply standard ImageNet mean/std normalization of images.
+        """
+
+        # images can be a path to a directory, and annotations can be a path to a json file in COCO format
+        # the following is based on the COCO class of https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py
+        if isinstance(images, str) and isinstance(annotations, str):
+            if verbose:
+                logger.info('Loading annotations into memory...')
+            tic = time.time()
+            with open(annotations, 'r') as f:
+                dataset = json.load(f)
+            assert type(dataset) == dict, 'Annotation file format {} not supported'.format(type(dataset))
+            if verbose:
+                logger.info('Done (t={:0.2f}s)'.format(time.time() - tic))
+
+            imgToAnns = defaultdict(list)
+            anns, imgs = {}, {}
+            if verbose:
+                # create index
+                logger.info('Creating index...')
+
+            if 'images' in dataset:
+                for img in dataset['images']:
+                    imgs[img['id']] = img
+
+            if 'annotations' in dataset:
+                for ann in dataset['annotations']:
+                    imgToAnns[ann['image_id']].append(ann)
+                    anns[ann['id']] = ann
+
+            if verbose:
+                logger.info('Index created!')
+
+            # next, turn into list of dicts (each dict should correspond to an image in 'images')
+            # this is based on https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/datasets/coco.py#L50
+            # map COCO file names to image ids, so the annotations of each image file can be looked up
+            file_name_to_id = {img["file_name"]: img["id"] for img in imgs.values()}
+            images_path = Path(images).glob('*/*.jpg')
+            images = []
+            annotations = []
+            for image_file_name in images_path:
+                if image_file_name.is_file():
+                    images.append(PIL.Image.open(image_file_name).convert("RGB"))
+                    image_id = file_name_to_id[image_file_name.name]
+                    annotation_dict = {}
+                    annotation_dict['boxes'] = [obj["bbox"] for obj in imgToAnns[image_id]]
+                    annotation_dict['area'] = [obj["area"] for obj in imgToAnns[image_id]]
+                    annotation_dict['classes'] = [obj["category_id"] for obj in imgToAnns[image_id]]
+                    annotations.append(annotation_dict)
+
+        is_batched = bool(
+            isinstance(images, (list, tuple))
+            and (isinstance(images[0], (PIL.Image.Image, np.ndarray, torch.Tensor)))
+        )
+
+        # step 1: make images a list of PIL images no matter what
+        if is_batched:
+            if isinstance(images[0], np.ndarray):
+                images = [PIL.Image.fromarray(image) for image in images]
+            elif isinstance(images[0], torch.Tensor):
+                images = [T.ToPILImage()(image).convert("RGB") for image in images]
+            if annotations is not None:
+                assert len(images) == len(annotations)
+        else:
+            if isinstance(images, PIL.Image.Image):
+                images = [images]
+            if annotations is not None:
+                annotations = [annotations]
+
+        # step 2: define transformations (resizing + normalization)
+        transformations = [ConvertCocoPolysToMask()]
+        if resize and size is not None:
+            transformations.append(Resize(size=size, max_size=max_size))
+        if normalize:
+            normalization = Compose([
+                ToTensor(),
+                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+            ])
+            transformations.append(normalization)
+        transforms = Compose(transformations)
+
+        # step 3: apply transformations to both images and annotations
+        transformed_images = []
+        transformed_annotations = []
+        if annotations is not None:
+            for image, annotation in zip(images, annotations):
+                image, annotation = transforms(image,
annotation) + transformed_images.append(image) + transformed_annotations.append(annotation) + else: + transformed_images = [transforms(image, None)[0] for image in images] + + # step 4: create NestedTensor which takes care of padding the pixels up to biggest image + # and creation of mask. We don't need the transformed targets for this + samples = nested_tensor_from_tensor_list(transformed_images) + + # return as BatchEncoding + data = {'pixel_values': samples.tensors, 'pixel_mask': samples.mask} + + if annotations is not None: + data['labels'] = transformed_annotations + + encoded_inputs = BatchEncoding(data=data) + + return encoded_inputs + + @property + def vocab_size(self) -> int: + return len(self.decoder) + + def get_vocab(self) -> Dict: + return dict(self.encoder, **self.added_tokens_encoder) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """ + Converts a connectionist-temporal-classification (CTC) output tokens into a single string. + """ + # group same tokens into non-repeating tokens in CTC style decoding + grouped_tokens = [token_group[0] for token_group in groupby(tokens)] + + # filter self.pad_token which is used as CTC-blank token + filtered_tokens = list(filter(lambda token: token != self.pad_token, grouped_tokens)) + + # replace delimiter token + string = "".join([" " if token == self.word_delimiter_token else token for token in filtered_tokens]).strip() + + if self.do_lower_case: + string = string.lower() + return string + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + ) -> str: + """ + special _decode function is needed for DETRTokenizer because added tokens should be treated exactly the + same as tokens of the base vocabulary and therefore the function `convert_tokens_to_string` has to be called on + the whole token list and not individually on added tokens + """ + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + + result = [] + for token in filtered_tokens: + if skip_special_tokens and token in self.all_special_ids: + continue + result.append(token) + + text = self.convert_tokens_to_string(result) + + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text \ No newline at end of file diff --git a/src/transformers/models/detr/tokenization_detr_fast.py b/src/transformers/models/detr/tokenization_detr_fast.py new file mode 100644 index 00000000000000..9ef37a723fbda4 --- /dev/null +++ b/src/transformers/models/detr/tokenization_detr_fast.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright Facebook AI Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tokenization classes for DETR.""" +from typing import List, Optional + +from tokenizers import ByteLevelBPETokenizer + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_detr import DetrTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {} + +PRETRAINED_VOCAB_FILES_MAP = {} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "facebook/detr-resnet-50": 1024, +} + +class DetrTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" DETR tokenizer (backed by HuggingFace's `tokenizers` library). + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = DetrTokenizer + + def __init__( + self, + vocab_file, + merges_file, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + add_prefix_space=False, + trim_offsets=True, + **kwargs + ): + super().__init__( + ByteLevelBPETokenizer( + vocab_file=vocab_file, + merges_file=merges_file, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + ), + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + **kwargs, + ) + self.add_prefix_space = add_prefix_space + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + DETR does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + + diff --git a/tests/test_modeling_detr.py b/tests/test_modeling_detr.py new file mode 100644 index 00000000000000..05b388b27588f0 --- /dev/null +++ b/tests/test_modeling_detr.py @@ -0,0 +1,378 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch DETR model. 
""" + + +import copy +import tempfile +import unittest + +from transformers import is_torch_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, ids_tensor + +from PIL import Image +import requests + + +if is_torch_available(): + import torch + import torchvision.transforms as T + + from transformers import ( + DetrConfig, + DetrModel, + DetrTokenizer, + DetrForObjectDetection, + ) + from transformers.models.detr.modeling_detr import ( + DetrDecoder, + DetrEncoder, + ) + + +def prepare_detr_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.ne(config.pad_token_id) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": attention_mask, + } + + +# @require_torch +# class DetrModelTester: +# def __init__( +# self, +# parent, +# batch_size=13, +# seq_length=7, +# is_training=True, +# use_labels=False, +# vocab_size=99, +# hidden_size=16, +# num_hidden_layers=2, +# num_attention_heads=4, +# intermediate_size=4, +# hidden_act="gelu", +# hidden_dropout_prob=0.1, +# attention_probs_dropout_prob=0.1, +# max_position_embeddings=20, +# eos_token_id=2, +# pad_token_id=1, +# bos_token_id=0, +# ): +# self.parent = parent +# self.batch_size = batch_size +# self.seq_length = seq_length +# self.is_training = is_training +# self.use_labels = use_labels +# self.vocab_size = vocab_size +# self.hidden_size = hidden_size +# self.num_hidden_layers = num_hidden_layers +# self.num_attention_heads = num_attention_heads +# self.intermediate_size = intermediate_size +# self.hidden_act = hidden_act +# self.hidden_dropout_prob = hidden_dropout_prob +# self.attention_probs_dropout_prob = attention_probs_dropout_prob +# self.max_position_embeddings = max_position_embeddings +# self.eos_token_id = eos_token_id +# self.pad_token_id = pad_token_id +# self.bos_token_id = bos_token_id + +# def prepare_config_and_inputs(self): +# input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) +# input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( +# 3, +# ) +# input_ids[:, -1] = self.eos_token_id # Eos Token + +# decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + +# config = DetrConfig( +# vocab_size=self.vocab_size, +# d_model=self.hidden_size, +# encoder_layers=self.num_hidden_layers, +# decoder_layers=self.num_hidden_layers, +# encoder_attention_heads=self.num_attention_heads, +# decoder_attention_heads=self.num_attention_heads, +# encoder_ffn_dim=self.intermediate_size, +# decoder_ffn_dim=self.intermediate_size, +# dropout=self.hidden_dropout_prob, +# attention_dropout=self.attention_probs_dropout_prob, +# max_position_embeddings=self.max_position_embeddings, +# eos_token_id=self.eos_token_id, +# bos_token_id=self.bos_token_id, +# pad_token_id=self.pad_token_id, +# ) +# inputs_dict = prepare_detr_inputs_dict(config, input_ids, decoder_input_ids) +# return config, inputs_dict + +# def prepare_config_and_inputs_for_common(self): +# config, inputs_dict 
= self.prepare_config_and_inputs() +# return config, inputs_dict + +# def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): +# model = DetrModel(config=config).get_decoder().to(torch_device).eval() +# input_ids = inputs_dict["input_ids"] +# attention_mask = inputs_dict["attention_mask"] + +# # first forward pass +# outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + +# output, past_key_values = outputs.to_tuple() + +# # create hypothetical multiple next token and extent to next_input_ids +# next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) +# next_attn_mask = ids_tensor((self.batch_size, 3), 2) + +# # append to next input_ids and +# next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) +# next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) + +# output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] +# output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)["last_hidden_state"] + +# # select random slice +# random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() +# output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() +# output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + +# self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + +# # test that outputs are equal for slice +# self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + +# def check_encoder_decoder_model_standalone(self, config, inputs_dict): +# model = DetrModel(config=config).to(torch_device).eval() +# outputs = model(**inputs_dict) + +# encoder_last_hidden_state = outputs.encoder_last_hidden_state +# last_hidden_state = outputs.last_hidden_state + +# with tempfile.TemporaryDirectory() as tmpdirname: +# encoder = model.get_encoder() +# encoder.save_pretrained(tmpdirname) +# encoder = DetrEncoder.from_pretrained(tmpdirname).to(torch_device) + +# encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[ +# 0 +# ] + +# self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3) + +# with tempfile.TemporaryDirectory() as tmpdirname: +# decoder = model.get_decoder() +# decoder.save_pretrained(tmpdirname) +# decoder = DetrDecoder.from_pretrained(tmpdirname).to(torch_device) + +# last_hidden_state_2 = decoder( +# input_ids=inputs_dict["decoder_input_ids"], +# attention_mask=inputs_dict["decoder_attention_mask"], +# encoder_hidden_states=encoder_last_hidden_state, +# encoder_attention_mask=inputs_dict["attention_mask"], +# )[0] + +# self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3) + + +# @require_torch +# class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): +# all_model_classes = ( +# (DetrModel, DetrForObjectDetection,) +# if is_torch_available() +# else () +# ) +# #all_generative_model_classes = (DetrForConditionalGeneration,) if is_torch_available() else () +# is_encoder_decoder = True +# test_pruning = False +# test_head_masking = False +# test_missing_keys = False + +# def setUp(self): +# self.model_tester = DetrModelTester(self) +# self.config_tester = ConfigTester(self, config_class=DetrConfig) + +# def test_config(self): +# self.config_tester.run_common_tests() + +# def test_save_load_strict(self): +# config, inputs_dict = 
self.model_tester.prepare_config_and_inputs() +# for model_class in self.all_model_classes: +# model = model_class(config) + +# with tempfile.TemporaryDirectory() as tmpdirname: +# model.save_pretrained(tmpdirname) +# model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) +# self.assertEqual(info["missing_keys"], []) + +# def test_decoder_model_past_with_large_inputs(self): +# config_and_inputs = self.model_tester.prepare_config_and_inputs() +# self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + +# def test_encoder_decoder_model_standalone(self): +# config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() +# self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + +# # DetrForSequenceClassification does not support inputs_embeds +# def test_inputs_embeds(self): +# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + +# for model_class in (DetrModel, DetrForConditionalGeneration, DetrForQuestionAnswering): +# model = model_class(config) +# model.to(torch_device) +# model.eval() + +# inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + +# if not self.is_encoder_decoder: +# input_ids = inputs["input_ids"] +# del inputs["input_ids"] +# else: +# encoder_input_ids = inputs["input_ids"] +# decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) +# del inputs["input_ids"] +# inputs.pop("decoder_input_ids", None) + +# wte = model.get_input_embeddings() +# if not self.is_encoder_decoder: +# inputs["inputs_embeds"] = wte(input_ids) +# else: +# inputs["inputs_embeds"] = wte(encoder_input_ids) +# inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + +# with torch.no_grad(): +# model(**inputs)[0] + + +def assert_tensors_close(a, b, atol=1e-12, prefix=""): + """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if torch.allclose(a, b, atol=atol): + return True + raise + except Exception: + pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item() + if a.numel() > 100: + msg = f"tensor values are {pct_different:.1%} percent different." 
+ else: + msg = f"{a} != {b}" + if prefix: + msg = prefix + ": " + msg + raise AssertionError(msg) + + +def _long_tensor(tok_lst): + return torch.tensor(tok_lst, dtype=torch.long, device=torch_device) + + +TOLERANCE = 1e-4 + + +# We will verify our outputs against the original implementation on an image of cute cats +def prepare_img(): + url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + im = Image.open(requests.get(url, stream=True).raw) + + # standard PyTorch mean-std input image normalization + transform = T.Compose([ + T.Resize(800), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + # mean-std normalize the input image (batch-size: 1) + img = transform(im).unsqueeze(0) + + return img + + +def prepare_detr_inputs(): + url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + im = Image.open(requests.get(url, stream=True).raw) + + tokenizer = DetrTokenizer() + + encoding = tokenizer(im) + + return encoding + + +@require_torch +@slow +class DetrModelIntegrationTests(unittest.TestCase): + # @cached_property + # def default_tokenizer(self): + # return DetrTokenizer.from_pretrained('facebook/detr-resnet-50') + + def test_inference_no_head(self): + model = DetrModel.from_pretrained('nielsr/detr-resnet-50-new').to(torch_device) + model.eval() + + encoding = prepare_detr_inputs() + pixel_values = encoding['pixel_values'].to(torch_device) + pixel_mask = encoding['pixel_mask'].to(torch_device) + + with torch.no_grad(): + outputs = model(pixel_values, pixel_mask) + + expected_shape = torch.Size((1, 100, 256)) + assert outputs.last_hidden_state.shape == expected_shape + expected_slice = torch.tensor([[0.0616, -0.5146, -0.4032], + [-0.7629, -0.4934, -1.7153], + [-0.4768, -0.6403, -0.7826]]).to(torch_device) + self.assertTrue(torch.allclose(outputs.last_hidden_state[0,:3,:3], expected_slice, atol=1e-4)) + + + def test_inference_object_detection_head(self): + model = DetrForObjectDetection.from_pretrained('nielsr/detr-resnet-50-new').to(torch_device) + model.eval() + + encoding = prepare_detr_inputs() + pixel_values = encoding['pixel_values'].to(torch_device) + pixel_mask = encoding['pixel_mask'].to(torch_device) + + with torch.no_grad(): + outputs = model(pixel_values, pixel_mask) + + expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels + 1)) + self.assertEqual(outputs.pred_logits.shape, expected_shape_logits) + expected_slice_logits = torch.tensor([[-19.1194, -0.0893, -11.0154], + [-17.3640, -1.8035, -14.0219], + [-20.0461, -0.5837, -11.1060]]).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_logits[0,:3,:3], expected_slice_logits, atol=1e-4)) + + expected_shape_boxes = torch.Size((1, model.config.num_queries, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + expected_slice_boxes = torch.tensor([[0.4433, 0.5302, 0.8853], + [0.5494, 0.2517, 0.0529], + [0.4998, 0.5360, 0.9956]]).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_boxes[0,:3,:3], expected_slice_boxes, atol=1e-4)) \ No newline at end of file diff --git a/tests/test_tokenization_detr.py b/tests/test_tokenization_detr.py new file mode 100644 index 00000000000000..af65088b4a97b4 --- /dev/null +++ b/tests/test_tokenization_detr.py @@ -0,0 +1,145 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the Wav2Vec2 tokenizer.""" + +import inspect +import json +import os +import random +import shutil +import tempfile +import unittest + +import torch +import numpy as np + +from PIL import Image +import requests + +from transformers.models.detr.tokenization_detr import DetrTokenizer + +class DetrTokenizerTest(unittest.TestCase): + tokenizer_class = DetrTokenizer + + def setUp(self): + super().setUp() + + # single PIL image + url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + self.img = Image.open(requests.get(url, stream=True).raw) + + # batch of PIL images + annotations + base_url = "http://images.cocodataset.org/val2017/" + image_urls = ["000000087038.jpg", "000000578500.jpg", "000000261982.jpg"] + + images = [] + for image_url in image_urls: + images.append(Image.open(requests.get(base_url + image_url, stream=True).raw)) + self.images = images + + # each target is a dict with keys "boxes" and "area" + self.annotations = [{'boxes': [[253.21, 271.07, 59.59, 60.97], + [226.04, 229.31, 11.59, 30.41], + [257.85, 224.48, 44.13, 97.0], + [68.18, 238.19, 16.18, 42.88], + [79.16, 232.26, 28.22, 51.12], + [98.4, 234.28, 19.52, 46.46], + [326.86, 223.46, 13.11, 38.67], + [155.27, 246.34, 14.87, 21.99], + [298.61, 316.85, 63.91, 47.19], + [345.41, 173.41, 72.94, 185.41], + [239.72, 225.38, 10.64, 33.06], + [167.02, 234.0, 15.78, 37.46], + [209.68, 231.08, 9.15, 34.53], + [408.29, 231.25, 17.12, 34.97], + [204.14, 229.02, 7.33, 34.96], + [195.32, 228.06, 10.65, 37.18], + [1, 190, 638, 101]], + 'area': [1391.4269500000005, + 232.4970999999999, + 1683.128300000001, + 413.4482999999996, + 563.41615, + 363.24569999999994, + 261.10905000000054, + 152.3124499999996, + 1268.7499999999989, + 4686.905750000002, + 204.17735000000013, + 277.0192999999997, + 241.29070000000024, + 243.31384999999952, + 188.82489999999987, + 294.38859999999977, + 6443], + }, + {'boxes': [[268.66, 143.28, 61.01, 53.52], + [204.04, 139.88, 24.87, 36.95], + [157.51, 135.92, 26.51, 54.95], + [117.06, 135.86, 37.88, 56.7], + [192.39, 137.85, 14.19, 46.86], + [311.46, 149.17, 156.95, 88.98], + [499.59, 116.56, 140.41, 173.44], + [1.86, 147.85, 132.27, 99.36], + [124.21, 150.01, 89.08, 5.36], + [344.97, 92.63, 6.72, 30.25], + [441.77, 71.9, 10.01, 41.52], + [118.63, 153.62, 8.78, 20.32], + [291.45, 179.46, 15.1, 10.3], + [498.7, 115.61, 141.3, 174.39]], + }, + {'boxes': [[0.0, 8.53, 251.9, 214.62], + [409.89, 120.81, 47.11, 92.04], + [84.85, 0.0, 298.6, 398.22], + [159.71, 211.53, 189.89, 231.01], + [357.69, 110.26, 99.31, 90.43]], + } + ] + + def get_tokenizer(self, **kwars): + return DetrTokenizer() + + # tests on single PIL image (inference only) + def test_tokenizer(self): + tokenizer = self.get_tokenizer() + encoding = tokenizer(self.img) + + self.assertEqual(encoding["pixel_values"].shape, (1, 3, 800, 1066)) + self.assertEqual(encoding["pixel_mask"].shape, (1, 800, 1066)) + + # tests on single PIL image (inference only, with resize set to False) + def test_tokenizer_no_resize(self): + tokenizer = self.get_tokenizer() + encoding = tokenizer(self.img, resize=False) + + 
self.assertEqual(encoding["pixel_values"].shape, (1, 3, 480, 640))
+        self.assertEqual(encoding["pixel_mask"].shape, (1, 480, 640))
+
+    # tests on batch of PIL images (inference only)
+    def test_tokenizer_batch(self):
+        tokenizer = self.get_tokenizer()
+        encoding = tokenizer(self.images)
+
+        self.assertEqual(encoding["pixel_values"].shape, (3, 3, 1120, 1332))
+        self.assertEqual(encoding["pixel_mask"].shape, (3, 1120, 1332))
+
+    # tests on batch of PIL images (training, i.e. with annotations)
+    def test_tokenizer_batch_training(self):
+        tokenizer = self.get_tokenizer()
+        encoding = tokenizer(self.images, self.annotations)
+
+        self.assertEqual(encoding["pixel_values"].shape, (3, 3, 1120, 1332))
+        self.assertEqual(encoding["pixel_mask"].shape, (3, 1120, 1332))
+
diff --git a/utils/check_repo.py b/utils/check_repo.py
index c8881baa651d9b..9b6e574bdbe144 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -30,6 +30,8 @@
 # Being in this list is an exception and should **not** be the rule.
 IGNORE_NON_TESTED = [
     # models to ignore for not tested
+    "DetrEncoder",  # Building part of bigger (tested) model.
+    "DetrDecoder",  # Building part of bigger (tested) model.
     "LEDEncoder",  # Building part of bigger (tested) model.
     "LEDDecoder",  # Building part of bigger (tested) model.
     "BartDecoderWrapper",  # Building part of bigger (tested) model.
@@ -75,6 +77,8 @@
 # should **not** be the rule.
 IGNORE_NON_AUTO_CONFIGURED = [
     # models to ignore for model xxx mapping
+    "DetrEncoder",
+    "DetrDecoder",
     "LEDEncoder",
     "LEDDecoder",
     "BartDecoder",