From 3054168c35c71a5b43f4294fb6fec86728e25629 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 4 Feb 2021 09:59:33 +0100 Subject: [PATCH 01/20] First commit (copy files from modeling_detr_v3) --- docs/source/model_doc/detr.rst | 80 + src/transformers/__init__.py | 18 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 4 + src/transformers/models/auto/modeling_auto.py | 12 + src/transformers/models/detr/__init__.py | 71 + .../models/detr/configuration_detr.py | 224 ++ ..._original_pytorch_checkpoint_to_pytorch.py | 293 +++ src/transformers/models/detr/modeling_detr.py | 2060 +++++++++++++++++ .../models/detr/tokenization_detr.py | 250 ++ .../models/detr/tokenization_detr_fast.py | 106 + tests/test_modeling_detr.py | 371 +++ utils/check_repo.py | 4 + 13 files changed, 3494 insertions(+) create mode 100644 docs/source/model_doc/detr.rst create mode 100644 src/transformers/models/detr/__init__.py create mode 100644 src/transformers/models/detr/configuration_detr.py create mode 100644 src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py create mode 100644 src/transformers/models/detr/modeling_detr.py create mode 100644 src/transformers/models/detr/tokenization_detr.py create mode 100644 src/transformers/models/detr/tokenization_detr_fast.py create mode 100644 tests/test_modeling_detr.py diff --git a/docs/source/model_doc/detr.rst b/docs/source/model_doc/detr.rst new file mode 100644 index 00000000000000..2fe08238844822 --- /dev/null +++ b/docs/source/model_doc/detr.rst @@ -0,0 +1,80 @@ +.. + Copyright 2020 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +DETR +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The DETR model was proposed in ` +<>`__ by . + +The abstract from the paper is the following: + +** + +Tips: + + + +DetrConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrConfig + :members: + + +DetrTokenizer +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrTokenizer + :members: build_inputs_with_special_tokens, get_special_tokens_mask, + create_token_type_ids_from_sequences, save_vocabulary + + +DetrTokenizerFast +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrTokenizerFast + :members: build_inputs_with_special_tokens, get_special_tokens_mask, + create_token_type_ids_from_sequences, save_vocabulary + + +DetrModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrModel + :members: forward + + +DetrForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrForConditionalGeneration + :members: forward + + +DetrForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrForSequenceClassification + :members: forward + + +DetrForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DetrForQuestionAnswering + :members: forward + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 41b18ad6d41394..8f24ebc61d289e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -134,6 +134,7 @@ "Wav2Vec2FeatureExtractor", "Wav2Vec2Processor", ], + "models.detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrTokenizer"], "models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"], "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], "models.auto": [ @@ -288,6 +289,7 @@ # tokenziers-backed objects if is_tokenizers_available(): # Fast tokenizers + _import_structure["models.detr"].append("DetrTokenizerFast") _import_structure["models.convbert"].append("ConvBertTokenizerFast") _import_structure["models.albert"].append("AlbertTokenizerFast") _import_structure["models.bart"].append("BartTokenizerFast") @@ -376,6 +378,14 @@ _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] # PyTorch models structure + _import_structure["models.detr"].extend( + [ + "DETR_PRETRAINED_MODEL_ARCHIVE_LIST", + "DetrForObjectDetection", + "DetrModel", + ] + ) + _import_structure["models.wav2vec2"].extend( [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1297,6 +1307,7 @@ load_tf2_weights_in_pytorch_model, ) from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig + from .models.detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig, DetrTokenizer from .models.auto import ( ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, @@ -1452,6 +1463,7 @@ from .utils.dummy_sentencepiece_objects import * if is_tokenizers_available(): + from .models.detr import DetrTokenizerFast from .models.albert import AlbertTokenizerFast from .models.bart import BartTokenizerFast from .models.barthez import BarthezTokenizerFast @@ -1491,6 +1503,12 @@ # Modeling if is_torch_available(): + from .models.detr import ( + DETR_PRETRAINED_MODEL_ARCHIVE_LIST, + DetrForObjectDetection, + DetrModel, + ) + # Benchmarks from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark_args import PyTorchBenchmarkArguments diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index f7f9a9e58ded44..3189b1e8156875 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -17,6 +17,7 @@ # limitations under the License. from . import ( + detr, albert, auto, bart, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 338f273757573b..23606e1a934869 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -19,6 +19,7 @@ from ...configuration_utils import PretrainedConfig from ..albert.configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig +from ..detr.configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig from ..bart.configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig from ..bert.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig from ..bert_generation.configuration_bert_generation import BertGenerationConfig @@ -75,6 +76,7 @@ (key, value) for pretrained_map in [ # Add archive maps here + DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LED_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -120,6 +122,7 @@ CONFIG_MAPPING = OrderedDict( [ # Add configs here + ("detr", DetrConfig), ("wav2vec2", Wav2Vec2Config), ("convbert", ConvBertConfig), ("led", LEDConfig), @@ -171,6 +174,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("detr", "Detr"), ("wav2vec2", "Wav2Vec2"), ("convbert", "ConvBERT"), ("led", "LED"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 39e8b70b3ce1ed..67ffcd12758e20 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -23,6 +23,10 @@ from ...utils import logging # Add modeling imports here +from ..detr.modeling_detr import ( + DetrForObjectDetection, + DetrModel, +) from ..albert.modeling_albert import ( AlbertForMaskedLM, AlbertForMultipleChoice, @@ -68,6 +72,10 @@ ) # Add modeling imports here +from ..detr.modeling_detr import ( + DetrForObjectDetection, + DetrModel, +) from ..convbert.modeling_convbert import ( ConvBertForMaskedLM, ConvBertForMultipleChoice, @@ -258,6 +266,7 @@ XLNetModel, ) from .configuration_auto import ( + DetrConfig, AlbertConfig, AutoConfig, BartConfig, @@ -313,6 +322,7 @@ MODEL_MAPPING = OrderedDict( [ # Base model mapping + (DetrConfig, DetrModel), (Wav2Vec2Config, Wav2Vec2Model), (ConvBertConfig, ConvBertModel), (LEDConfig, LEDModel), @@ -396,6 +406,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( [ # Model with LM heads mapping + (Wav2Vec2Config, Wav2Vec2ForMaskedLM), (ConvBertConfig, ConvBertForMaskedLM), (LEDConfig, LEDForConditionalGeneration), @@ -495,6 +506,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + (LEDConfig, LEDForConditionalGeneration), (BlenderbotSmallConfig, BlenderbotSmallForConditionalGeneration), (MT5Config, MT5ForConditionalGeneration), diff --git a/src/transformers/models/detr/__init__.py b/src/transformers/models/detr/__init__.py new file mode 100644 index 00000000000000..182650212c9361 --- /dev/null +++ b/src/transformers/models/detr/__init__.py @@ -0,0 +1,71 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING +from ...file_utils import _BaseLazyModule, is_torch_available, is_tokenizers_available +_import_structure = { + "configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"], + "tokenization_detr": ["DetrTokenizer"], +} + +if is_tokenizers_available(): + _import_structure["tokenization_detr_fast"] = ["DetrTokenizerFast"] + +if is_torch_available(): + _import_structure["modeling_detr"] = [ + "DETR_PRETRAINED_MODEL_ARCHIVE_LIST", + "DetrForObjectDetection", + "DetrModel", + "DetrPreTrainedModel", + ] + + + + +if TYPE_CHECKING: + from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig + from .tokenization_detr import DetrTokenizer + + if is_tokenizers_available(): + from .tokenization_detr_fast import DetrTokenizerFast + + if is_torch_available(): + from .modeling_detr import ( + DETR_PRETRAINED_MODEL_ARCHIVE_LIST, + DetrForObjectDetection, + DetrModel, + DetrPreTrainedModel, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py new file mode 100644 index 00000000000000..6b9c76a476c760 --- /dev/null +++ b/src/transformers/models/detr/configuration_detr.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright Facebook AI Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DETR model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +DETR_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/detr-resnet-50": "https://huggingface.co/facebook/detr-resnet-50/resolve/main/config.json", + # See all DETR models at https://huggingface.co/models?filter=detr +} + + +class DetrConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.DetrModel`. + It is used to instantiate a DETR model according to the specified arguments, defining the model + architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of + the DETR `facebook/detr-resnet-50 `__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used + to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` + for more information. + + + Args: + num_queries (:obj:`int`, `optional`, defaults to 100): + Number of object queries, i.e. detection slots. This is the maximal number of objects + :class:`~transformers.DetrModel` can detect in a single image. For COCO, we recommend 100 queries. + d_model (:obj:`int`, `optional`, defaults to 256): + Dimensionality of the layers. + encoder_layers (:obj:`int`, `optional`, defaults to 6): + Number of encoder layers. + decoder_layers (:obj:`int`, `optional`, defaults to 6): + Number of decoder layers. + encoder_attention_heads (:obj:`int`, `optional`, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_attention_heads (:obj:`int`, `optional`, defaults to 8): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + encoder_ffn_dim (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported. + dropout (:obj:`float`, `optional`, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + classifier_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for classifier. + max_position_embeddings (:obj:`int`, `optional`, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + encoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the encoder. See the `LayerDrop paper `__ for more details. + decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): + The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). + auxiliary_loss (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether auxiliary decoding losses (loss at each decoder layer) are to be used. + position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`sine`): + Type of position embeddings to be used on top of the image features. One of 'sine' or 'learned'. + backbone (:obj:`bool`, `optional`, defaults to :obj:`resnet50`): + Name of convolutional backbone to use. Currently only resnet of the Torchvision package is supported. + train_backbone (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to train (fine-tune) the backbone. + dilation (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to replace stride with dilation in the last convolutional block (DC5). + masks (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to train the segmentation head. + class_cost (:obj:`float`, `optional`, defaults to 1): + Relative weight of the classification error in the Hungarian matching cost. + bbox_cost (:obj:`float`, `optional`, defaults to 5): + Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost. + giou_cost (:obj:`float`, `optional`, defaults to 2): + Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost. + mask_loss_coefficient (:obj:`float`, `optional`, defaults to 1): + Relative weight of the Focal loss in the panoptic segmentation loss. + dice_loss_coefficient (:obj:`float`, `optional`, defaults to 1): + Relative weight of the DICE/F-1 loss in the panoptic segmentation loss. + bbox_loss_coefficient (:obj:`float`, `optional`, defaults to 5): + Relative weight of the L1 bounding box loss in the object detection loss. + giou_loss_coefficient (:obj:`float`, `optional`, defaults to 2): + Relative weight of the generalized IoU loss in the object detection loss. + eos_coefficient (:obj:`float`, `optional`, defaults to 0.1): + Relative classification weight of the 'no-object' class in the object detection loss. + Example:: + + >>> from transformers import DetrModel, DetrConfig + + >>> # Initializing a DETR facebook/detr-resnet-50 style configuration + >>> configuration = DetrConfig() + + >>> # Initializing a model from the facebook/detr-resnet-50 style configuration + >>> model = DetrModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "detr" + keys_to_ignore_at_inference = ["past_key_values"] + def __init__( + self, + num_queries=100, + max_position_embeddings=1024, + encoder_layers=6, + encoder_ffn_dim=2048, + encoder_attention_heads=8, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=8, + encoder_layerdrop=0.0, + decoder_layerdrop=0.0, + use_cache=True, + is_encoder_decoder=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + classifier_dropout=0.0, + scale_embedding=False, + gradient_checkpointing=False, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + auxiliary_loss=False, + position_embedding_type='sine', + backbone='resnet50', + train_backbone=True, + dilation=False, + masks=False, + class_cost=1, + bbox_cost=5, + giou_cost=2, + mask_loss_coefficient=1, + dice_loss_coefficient=1, + bbox_loss_coefficient=5, + giou_loss_coefficient=2, + eos_coefficient=0.1, + **kwargs + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + decoder_start_token_id=decoder_start_token_id, + **kwargs + ) + + self.num_queries = num_queries + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.decoder_layerdrop = decoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.gradient_checkpointing = gradient_checkpointing + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.auxiliary_loss = auxiliary_loss + self.position_embedding_type = position_embedding_type + self.backbone = backbone + self.train_backbone = train_backbone + self.dilation = dilation + self.masks = masks + # Hungarian matcher + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + # Loss coefficients + self.mask_loss_coefficient = mask_loss_coefficient + self.dice_loss_coefficient = dice_loss_coefficient + self.bbox_loss_coefficient = bbox_loss_coefficient + self.giou_loss_coefficient = giou_loss_coefficient + self.eos_coefficient = eos_coefficient + + + + @property + def num_attention_heads(self) -> int: + return self.encoder_attention_heads + + @property + def hidden_size(self) -> int: + return self.d_model \ No newline at end of file diff --git a/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 00000000000000..27c4cc3cc2a085 --- /dev/null +++ b/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,293 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DETR checkpoints.""" + + +import argparse +from pathlib import Path + +import torch +import torchvision.transforms as T +from packaging import version + +from PIL import Image +import requests + +from transformers import ( + DetrConfig, + DetrModel, + DetrForObjectDetection, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +# here we list all keys to be renamed (original name on the left, our name on the right) +rename_keys = [] +for i in range(6): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append(("transformer.encoder.layers." + str(i) + ".self_attn.out_proj.weight", "encoder.layers." + str(i) + ".self_attn.out_proj.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".self_attn.out_proj.bias", "encoder.layers." + str(i) + ".self_attn.out_proj.bias")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".linear1.weight", "encoder.layers." + str(i) + ".fc1.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".linear1.bias", "encoder.layers." + str(i) + ".fc1.bias")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".linear2.weight", "encoder.layers." + str(i) + ".fc2.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".linear2.bias", "encoder.layers." + str(i) + ".fc2.bias")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".norm1.weight", "encoder.layers." + str(i) + ".self_attn_layer_norm.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".norm1.bias", "encoder.layers." + str(i) + ".self_attn_layer_norm.bias")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".norm2.weight", "encoder.layers." + str(i) + ".final_layer_norm.weight")) + rename_keys.append(("transformer.encoder.layers." + str(i) + ".norm2.bias", "encoder.layers." + str(i) + ".final_layer_norm.bias")) + # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms + rename_keys.append(("transformer.decoder.layers." + str(i) + ".self_attn.out_proj.weight", "decoder.layers." + str(i) + ".self_attn.out_proj.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".self_attn.out_proj.bias","decoder.layers." + str(i) + ".self_attn.out_proj.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".multihead_attn.out_proj.weight", "decoder.layers." + str(i) + ".encoder_attn.out_proj.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".multihead_attn.out_proj.bias", "decoder.layers." + str(i) + ".encoder_attn.out_proj.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".linear1.weight", "decoder.layers." + str(i) + ".fc1.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".linear1.bias", "decoder.layers." + str(i) + ".fc1.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".linear2.weight", "decoder.layers." + str(i) + ".fc2.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".linear2.bias", "decoder.layers." + str(i) + ".fc2.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm1.weight", "decoder.layers." + str(i) + ".self_attn_layer_norm.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm1.bias", "decoder.layers." + str(i) + ".self_attn_layer_norm.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm2.weight", "decoder.layers." + str(i) + ".encoder_attn_layer_norm.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm2.bias", "decoder.layers." + str(i) + ".encoder_attn_layer_norm.bias")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm3.weight", "decoder.layers." + str(i) + ".final_layer_norm.weight")) + rename_keys.append(("transformer.decoder.layers." + str(i) + ".norm3.bias", "decoder.layers." + str(i) + ".final_layer_norm.bias")) + + +# convolutional projection + query embeddings + layernorm of decoder +rename_keys.extend([("input_proj.weight", "input_projection.weight"), +("input_proj.bias", "input_projection.bias"), +("query_embed.weight", "query_position_embeddings.weight"), +("transformer.decoder.norm.weight", "decoder.layernorm.weight"), +("transformer.decoder.norm.bias", "decoder.layernorm.bias")]) + + +def remove_object_detection_heads_(state_dict): + ignore_keys = [ + "class_embed.weight", + "class_embed.bias", + "bbox_embed.layers.0.weight", + "bbox_embed.layers.0.bias", + "bbox_embed.layers.1.weight", + "bbox_embed.layers.1.bias", + "bbox_embed.layers.2.weight", + "bbox_embed.layers.2.bias" + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(state_dict, old, new): + val = state_dict.pop(old) + state_dict[new] = val + + +def read_in_q_k_v(state_dict): + # first: transformer encoder + for i in range(6): + # read in weights + bias of input projection layer (in PyTorch's MultiHeadAttention, this is a single matrix + bias) + in_proj_weight = state_dict.pop("transformer.encoder.layers." + str(i) + ".self_attn.in_proj_weight") + in_proj_bias = state_dict.pop("transformer.encoder.layers." + str(i) + ".self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict["encoder.layers." + str(i) + ".self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict["encoder.layers." + str(i) + ".self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict["encoder.layers." + str(i) + ".self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict["encoder.layers." + str(i) + ".self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict["encoder.layers." + str(i) + ".self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict["encoder.layers." + str(i) + ".self_attn.v_proj.bias"] = in_proj_bias[-256:] + # next: transformer decoder (which is a bit more complex because it also includes cross-attention) + for i in range(6): + # read in weights + bias of input projection layer of self-attention + in_proj_weight = state_dict.pop("transformer.decoder.layers." + str(i) + ".self_attn.in_proj_weight") + in_proj_bias = state_dict.pop("transformer.decoder.layers." + str(i) + ".self_attn.in_proj_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict["decoder.layers." + str(i) + ".self_attn.q_proj.weight"] = in_proj_weight[:256, :] + state_dict["decoder.layers." + str(i) + ".self_attn.q_proj.bias"] = in_proj_bias[:256] + state_dict["decoder.layers." + str(i) + ".self_attn.k_proj.weight"] = in_proj_weight[256:512, :] + state_dict["decoder.layers." + str(i) + ".self_attn.k_proj.bias"] = in_proj_bias[256:512] + state_dict["decoder.layers." + str(i) + ".self_attn.v_proj.weight"] = in_proj_weight[-256:, :] + state_dict["decoder.layers." + str(i) + ".self_attn.v_proj.bias"] = in_proj_bias[-256:] + # read in weights + bias of input projection layer of cross-attention + in_proj_weight_cross_attn = state_dict.pop("transformer.decoder.layers." + str(i) + ".multihead_attn.in_proj_weight") + in_proj_bias_cross_attn = state_dict.pop("transformer.decoder.layers." + str(i) + ".multihead_attn.in_proj_bias") + # next, add query, keys and values (in that order) of cross-attention to the state dict + state_dict["decoder.layers." + str(i) + ".encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :] + state_dict["decoder.layers." + str(i) + ".encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256] + state_dict["decoder.layers." + str(i) + ".encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[256:512, :] + state_dict["decoder.layers." + str(i) + ".encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512] + state_dict["decoder.layers." + str(i) + ".encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :] + state_dict["decoder.layers." + str(i) + ".encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:] + + +# since we renamed the classification heads of the object detection model, we need to rename the original keys: +rename_keys_object_detection_model = [ +("class_embed.weight", "class_labels_classifier.weight"), +("class_embed.bias", "class_labels_classifier.bias"), +("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"), +("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"), +("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"), +("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"), +("bbox_embed.layers.2.weight","bbox_predictor.layers.2.weight"), +("bbox_embed.layers.2.bias","bbox_predictor.layers.2.bias"), +] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + im = Image.open(requests.get(url, stream=True).raw) + + # standard PyTorch mean-std input image normalization + transform = T.Compose([ + T.Resize(800), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + # mean-std normalize the input image (batch-size: 1) + img = transform(im).unsqueeze(0) + + return img + + +# COCO classes +CLASSES = [ + 'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', + 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', + 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', + 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', + 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', + 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', + 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', + 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', + 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', + 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', + 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', + 'toothbrush' +] + + +@torch.no_grad() +def convert_detr_checkpoint(task, backbone, dilation, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our DETR structure. + """ + + config = DetrConfig() + img = prepare_img() + + logger.info(f"Converting model for task {task}, with a {backbone} backbone, dilation set to {dilation}...") + + if task == "base_model": + # load model from torch hub + detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).eval() + state_dict = detr.state_dict() + # rename keys + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict) + # remove classification heads + remove_object_detection_heads_(state_dict) + # finally, create model and load state dict + model = DetrModel(config).eval() + model.load_state_dict(state_dict) + # verify our conversion on the image + outputs = model(img) + assert outputs.last_hidden_state.shape == (1, config.num_queries, config.d_model) + expected_slice = torch.tensor([[0.0616, -0.5146, -0.4032], + [-0.7629, -0.4934, -1.7153], + [-0.4768, -0.6403, -0.7826]]) + assert torch.allclose(outputs.last_hidden_state[0,:3,:3], expected_slice, atol=1e-4) + + elif task == "object_detection": + # coco has 91 labels + config.num_labels = 91 + config.id2label = {v: k for v, k in enumerate(CLASSES)} + config.label2id = {k: v for v, k in enumerate(CLASSES)} + # load model from torch hub + if backbone == 'resnet_50' and not dilation: + detr = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True).eval() + elif backbone == 'resnet_50' and dilation: + detr = torch.hub.load('facebookresearch/detr', 'detr_dc5_resnet50', pretrained=True).eval() + config.dilation = True + elif backbone == 'resnet_101' and not dilation: + detr = torch.hub.load('facebookresearch/detr', 'detr_resnet101', pretrained=True).eval() + config.backbone = 'resnet_101' + elif backbone == 'resnet_101' and dilation: + detr = torch.hub.load('facebookresearch/detr', 'detr_dc5_resnet101', pretrained=True).eval() + config.backbone = 'resnet_101' + config.dilation = True + else: + raise ValueError(f"Not supported: {backbone} with {dilation}") + + state_dict = detr.state_dict() + # rename keys + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + # query, key and value matrices need special treatment + read_in_q_k_v(state_dict) + # rename classification heads + for src, dest in rename_keys_object_detection_model: + rename_key(state_dict, src, dest) + # finally, create model and load state dict + model = DetrForObjectDetection(config).eval() + model.load_state_dict(state_dict) + # verify our conversion + original_outputs = detr(img) + outputs = model(img) + assert torch.allclose(outputs.pred_logits, original_outputs['pred_logits'], atol=1e-4) + assert torch.allclose(outputs.pred_boxes, original_outputs['pred_boxes'], atol=1e-4) + + elif task == "panoptic_segmentation": + # First, load in original detr from torch hub + if backbone == 'resnet_50' and not dilation: + detr, postprocessor = torch.hub.load('facebookresearch/detr', 'detr_resnet50_panoptic', + pretrained=True, return_postprocessor=True, num_classes=250) + detr.eval() + elif backbone == 'resnet_50' and dilation: + detr, postprocessor = torch.hub.load('facebookresearch/detr', 'detr_dc5_resnet50_panoptic', + pretrained=True, return_postprocessor=True, num_classes=250) + detr.eval() + config.dilation = True + elif backbone == 'resnet_101' and not dilation: + detr, postprocessor = torch.hub.load('facebookresearch/detr', 'detr_resnet101_panoptic', + pretrained=True, return_postprocessor=True, num_classes=250) + detr.eval() + config.backbone = 'resnet_101' + else: + print("Not supported:", backbone, dilation) + + else: + print("Task not in list of supported tasks:", task) + + # Save model + logger.info(f"Saving PyTorch model to {pytorch_dump_folder_path}...") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--task", default='base_model', type=str, help="""Task for which to convert a checkpoint. One of 'base_model', + 'object_detection' or 'panoptic_segmentation'.""") + parser.add_argument("--backbone", default='resnet_50', type=str, help="Which backbone to use. One of 'resnet50', 'resnet101'.") + parser.add_argument("--dilation", default=False, action="store_true", help="Whether to apply dilated convolution.") + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + args = parser.parse_args() + convert_detr_checkpoint(args.task, args.backbone, args.dilation, args.pytorch_dump_folder_path) \ No newline at end of file diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py new file mode 100644 index 00000000000000..ac5c7b19785b6c --- /dev/null +++ b/src/transformers/models/detr/modeling_detr.py @@ -0,0 +1,2060 @@ +# coding=utf-8 +# Copyright 2021 Facebook AI Research The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DETR model. """ + + +import math +import random +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, List +from scipy.optimize import linear_sum_assignment + +import torch +import torch.nn.functional as F +import torchvision +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.boxes import box_area +from torch import nn +from torch.nn import CrossEntropyLoss +from torch import Tensor +import torch.distributed as dist + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithCrossAttentions, + # BaseModelOutputWithPastAndCrossAttentions, (Niels): don't think we need this one as DETR uses parallel decoding + Seq2SeqModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import logging +from .configuration_detr import DetrConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DetrConfig" +_TOKENIZER_FOR_DOC = "DetrTokenizer" + + +DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/detr-resnet-50", + # See all DETR models at https://huggingface.co/models?filter=detr +] + + +@dataclass +class BaseModelOutputWithCrossAttentionsAndIntermediateHiddenStates(BaseModelOutputWithCrossAttentions): + """ + This class adds one attribute to BaseModelOutputWithCrossAttentions, namely an optional stack of intermediate decoder + activations, i.e. the output of each decoder layer, each of them gone through a layernorm. + Args: + intermediate_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(config.decoder_layers, batch_size, sequence_length, hidden_size)`): + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + +@dataclass +class DetrObjectDetectionOutput(ModelOutput): + """ + Output type of :class:`~transformers.DetrForObjectDetection`. + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` are provided)): + Total loss as the sum of (...). + pred_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, num_classes + 1)`): + Classification logits (including no-object) for all queries. + pred_boxes (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_queries, 4)`): + Normalized boxes coordinates for all queries, represented as (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). See PostProcess for information on how to retrieve the + unnormalized bounding box. + auxiliary_outputs (:obj:`list[Dict]`, `optional`): + Optional, only returned when auxilary losses are activated (i.e. config.auxiliary_loss is set to True) and labels are provided. It is a + list of dictionnaries containing the two above keys (pred_logits and pred_boxes) for each decoder layer. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2 tensors + of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + pred_logits: torch.FloatTensor = None + pred_boxes: torch.FloatTensor = None + auxiliary_outputs: Optional[List[Dict]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +## BELOW: utilities copied from +# https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/util/misc.py + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + """ + Data type that handles different types of inputs (either list of images or list of sequences), + and computes the padded output (with masking). + """ + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('Not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +## BELOW: utilities copied from +# https://github.com/facebookresearch/detr/blob/master/backbone.py + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it user-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=True, norm_layer=FrozenBatchNorm2d) + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), float("-inf")) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None +): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) + + +class DetrSinePositionEmbedding(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.embedding_dim = embedding_dim + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.embedding_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class DetrLearnedPositionEmbedding(nn.Module): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + def __init__(self, embedding_dim=256): + super().__init__() + self.row_embeddings = nn.Embedding(50, embedding_dim) + self.column_embeddings = nn.Embedding(50, embedding_dim) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embeddings.weight) + nn.init.uniform_(self.column_embeddings.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.column_embeddings(i) + y_emb = self.row_embeddings(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(config): + N_steps = config.d_model // 2 + if config.position_embedding_type == 'sine': + # TODO find a better way of exposing other arguments + position_embedding = DetrSinePositionEmbedding(N_steps, normalize=True) + elif config.position_embedding_type == 'learned': + position_embedding = DetrLearnedPositionEmbedding(N_steps) + else: + raise ValueError(f"not supported {config.position_embedding_type}") + + return position_embedding + + +# class DetrLearnedPositionalEmbedding(nn.Embedding): +# """ +# This module learns positional embeddings up to a fixed maximum size. +# """ + +# def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): +# assert padding_idx is not None, "`padding_idx` should not be None, but of type int" +# super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) + +# def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): +# """`input_ids_shape` is expected to be [bsz x seqlen].""" +# bsz, seq_len = input_ids_shape[:2] +# positions = torch.arange( +# past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device +# ) +# return super().forward(positions) + + +class DetrAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. + + Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." + self.scaling = self.head_dim ** -0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + # added (Niels) + def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]): + return tensor if position_embeddings is None else tensor + position_embeddings + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + key_value_states: Optional[torch.Tensor] = None, + key_value_position_embeddings: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = hidden_states.size() + + # Added (Niels): add position embeddings to the hidden states before projecting to queries and keys + if position_embeddings is not None: + hidden_states_original = hidden_states + hidden_states = self.with_pos_embed(hidden_states, position_embeddings) + + # Added (Niels): add key-value position embeddings to the key value states + if key_value_position_embeddings is not None: + key_value_states_original = key_value_states + key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states_original), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + assert attn_weights.size() == ( + bsz * self.num_heads, + tgt_len, + src_len, + ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" + + if attention_mask is not None: + assert attention_mask.size() == ( + bsz, + 1, + tgt_len, + src_len, + ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = F.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + assert attn_output.size() == ( + bsz * self.num_heads, + tgt_len, + self.head_dim, + ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" + + attn_output = ( + attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .reshape(bsz, tgt_len, embed_dim) + ) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class DetrEncoderLayer(nn.Module): + def __init__(self, config: DetrConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = DetrAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor = None, + output_attentions: bool = False): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_embeddings (:obj:`torch.FloatTensor`, `optional`): position embeddings, to be added to hidden_states. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, + output_attentions=output_attentions + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class DetrDecoderLayer(nn.Module): + def __init__(self, config: DetrConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = DetrAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = DetrAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + query_position_embeddings: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (:obj:`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + position_embeddings=query_position_embeddings, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + position_embeddings=query_position_embeddings, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + key_value_position_embeddings=position_embeddings, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class DetrClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +class DetrPreTrainedModel(PreTrainedModel): + config_class = DetrConfig + base_model_prefix = "model" + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +DETR_START_DOCSTRING = r""" + This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic + methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior. + + Parameters: + config (:class:`~transformers.DetrConfig`): + Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +DETR_GENERATION_EXAMPLE = r""" + Summarization example:: + + >>> from transformers import DetrTokenizer, DetrForConditionalGeneration, DetrConfig + + >>> model = DetrForConditionalGeneration.from_pretrained('facebook/detr-resnet-50') + >>> tokenizer = DetrTokenizer.from_pretrained('facebook/detr-resnet-50') + + >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." + >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') + + >>> # Generate Summary + >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) + >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) +""" + +DETR_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using :class:`~transformers.DetrTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Provide for translation and summarization training. By default, the model will create this tensor by + shifting the :obj:`input_ids` to the right, following the paper. + decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): + Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will + also be used by default. + + If you want to change padding behavior, you should read :func:`modeling_detr._prepare_decoder_inputs` and + modify to your needs. See diagram 1 in `the paper `__ for more + information on the default strategy. + encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): + Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: + :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, + `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert :obj:`input_ids` indices into associated + vectors than the model's internal embedding lookup matrix. + decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded + representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` + have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert + :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` + takes the value of :obj:`inputs_embeds`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +class DetrEncoder(DetrPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + :class:`DetrEncoderLayer`. + + Args: + config: DetrConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: DetrConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + # if embed_tokens is not None: + # self.embed_tokens = embed_tokens + # else: + # self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + + # self.embed_positions = DetrLearnedPositionalEmbedding( + # config.max_position_embeddings, + # embed_dim, + # self.padding_idx, + # ) + + self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)]) + + # (Niels) in the original DETR, no layernorm is used for the Encoder, as "normalize_before" is set to False + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + position_embeddings=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.DetrTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + # if input_ids is not None and inputs_embeds is not None: + # raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + # elif input_ids is not None: + # input_shape = input_ids.size() + # input_ids = input_ids.view(-1, input_shape[-1]) + # elif inputs_embeds is not None: + # input_shape = inputs_embeds.size()[:-1] + # else: + # raise ValueError("You have to specify either input_ids or inputs_embeds") + + # if inputs_embeds is None: + # inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # embed_pos = self.embed_positions(input_shape) + + # # add position embeddings + # hidden_states = inputs_embeds + embed_pos + # hidden_states = self.layernorm_embedding(hidden_states) + # hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + hidden_states = inputs_embeds + # (Niels) comment out layernorm, see __init__ above + #hidden_states = self.layernorm_embedding(hidden_states) + hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + ) + else: + # we add position_embeddings as extra input to the encoder_layer + layer_outputs = encoder_layer(hidden_states, + attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class DetrDecoder(DetrPreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`DetrDecoderLayer`. + + Some small tweaks for DETR: + - position_embeddings and query_position_embeddings are added to the forward pass. + - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. + + Args: + config: DetrConfig + embed_tokens (torch.nn.Embedding): output embedding + """ + + def __init__(self, config: DetrConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + # don't think we need embed_tokens (output tokens) here, since we are just updating the query embeddings + + # if embed_tokens is not None: + # self.embed_tokens = embed_tokens + # else: + # self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + # self.embed_positions = DetrLearnedPositionalEmbedding( + # config.max_position_embeddings, + # config.d_model, + # self.padding_idx, + # ) + self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)]) + # in DETR, the decoder uses layernorm after the last decoder layer output + self.layernorm = nn.LayerNorm(config.d_model) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + inputs_embeds=None, + position_embeddings=None, + query_position_embeddings=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using :class:`~transformers.DetrTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` + for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up + decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last + :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of + shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, + sequence_length)`. + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert :obj:`input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under + returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors + for more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + # if input_ids is not None and inputs_embeds is not None: + # raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + # elif input_ids is not None: + # input_shape = input_ids.size() + # input_ids = input_ids.view(-1, input_shape[-1]) + # elif inputs_embeds is not None: + # input_shape = inputs_embeds.size()[:-1] + # else: + # raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + # to do: should be updated, because no input_ids here + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + # added this (Niels) to infer input_shape: + if inputs_embeds is not None: + hidden_states = inputs_embeds + input_shape = inputs_embeds.size()[:-1] + + combined_attention_mask = None + # (Niels): following lines are not required as DETR uses parallel decoding instead of autoregressive + # # create causal mask + # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + # combined_attention_mask = None + + # if input_shape[-1] > 1: + # combined_attention_mask = _make_causal_mask( + # input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + # ).to(self.device) + + if attention_mask is not None and combined_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = combined_attention_mask + _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + + # (Niels): following lines are not required because adding position embeddings happens in DetrDecoderLayer + # embed positions + # positions = self.embed_positions(input_shape, past_key_values_length) + + # hidden_states = inputs_embeds + positions + # hidden_states = self.layernorm_embedding(inputs_embeds) + # hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) + + # (Niels): added an optional list: + intermediate = [] if self.config.auxiliary_loss else None + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + next_decoder_cache = () if use_cache else None + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...") + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, use_cache) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + combined_attention_mask, + encoder_hidden_states, + encoder_attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=combined_attention_mask, + position_embeddings=position_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if self.config.auxiliary_loss: + intermediate.append(self.layernorm(hidden_states)) + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + all_cross_attentions += (layer_outputs[2],) + + # finally, apply layernorm + hidden_states = self.layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # stack intermediate decoder activations + if self.config.auxiliary_loss: + intermediate = torch.stack(intermediate) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions, + intermediate] + if v is not None + ) + return BaseModelOutputWithCrossAttentionsAndIntermediateHiddenStates( + last_hidden_state=hidden_states, + #past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + intermediate_hidden_states=intermediate, + ) + + +@add_start_docstrings( + """The bare DETR Model (consisting of a backbone and encoder-decoder Transformer) outputting raw hidden-states without + any specific head on top.""", + DETR_START_DOCSTRING, +) +class DetrModel(DetrPreTrainedModel): + def __init__(self, config: DetrConfig): + super().__init__(config) + + # Create backbone + positional encoding + backbone = Backbone(config.backbone, config.train_backbone, config.masks, config.dilation) + position_embeddings = build_position_encoding(config) + self.backbone = Joiner(backbone, position_embeddings) + + # Create projection layer + self.input_projection = nn.Conv2d(backbone.num_channels, config.d_model, kernel_size=1) + + self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) + + self.encoder = DetrEncoder(config) + self.decoder = DetrDecoder(config) + + self.init_weights() + + # def get_input_embeddings(self): + # return self.shared + + # def set_input_embeddings(self, value): + # self.shared = value + # self.encoder.embed_tokens = self.shared + # self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="facebook/detr-resnet-50", + output_type=Seq2SeqModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + samples: NestedTensor=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # First, sent images through Backbone to obtain the features (includes features map, mask and position embeddings) + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + features, position_embeddings_list = self.backbone(samples) + + src, mask = features[-1].decompose() + assert mask is not None + + # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + src = self.input_projection(src) + + # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC + # In other words, turn their shape into (batch_size, sequence_length, hidden_size) + batch_size, c, h, w = src.shape + src = src.flatten(2).permute(0, 2, 1) + position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1) + mask = ~mask.flatten(1) + + # Fourth, sent src + mask + position embeddings through encoder + # src is a Tensor of shape (batch_size, heigth*width, hidden_size) + # mask is a Tensor of shape (batch_size, heigth*width) + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=src, + attention_mask=mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output) + query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) + tgt = torch.zeros_like(query_position_embeddings) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + inputs_embeds=tgt, + attention_mask=None, + position_embeddings=position_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + #past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """DETR Model (consisting of a backbone and encoder-decoder Transformer) with an object detection head on top, + for tasks such as COCO.""", + DETR_START_DOCSTRING, +) +class DetrForObjectDetection(DetrPreTrainedModel): + def __init__(self, config: DetrConfig): + super().__init__(config) + + # Create backbone + positional encoding + backbone = Backbone(config.backbone, config.train_backbone, config.masks, config.dilation) + position_embeddings = build_position_encoding(config) + self.backbone = Joiner(backbone, position_embeddings) + + # Create projection layer + self.input_projection = nn.Conv2d(backbone.num_channels, config.d_model, kernel_size=1) + + self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) + + self.encoder = DetrEncoder(config) + self.decoder = DetrDecoder(config) + + # Object detection heads + self.class_labels_classifier = nn.Linear(config.d_model, config.num_labels + 1) + self.bbox_predictor = MLP(input_dim=config.d_model, hidden_dim=config.d_model, + output_dim=4, num_layers=3) + + self.init_weights() + + # def get_input_embeddings(self): + # return self.shared + + # def set_input_embeddings(self, value): + # self.shared = value + # self.encoder.embed_tokens = self.shared + # self.decoder.embed_tokens = self.shared + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + # copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING) + # @add_code_sample_docstrings( + # tokenizer_class=_TOKENIZER_FOR_DOC, + # checkpoint="facebook/detr-resnet-50", + # output_type=DetrObjectDetectionOutput, + # config_class=_CONFIG_FOR_DOC, + # ) + def forward( + self, + samples: NestedTensor=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`List[Dict]` of len :obj:`(batch_size,)`, `optional`): + Labels for computing the bipartite matching loss. List of dicts, each dictionary containing 2 keys: 'class_labels' and + 'boxes' (the class labels and bounding boxes of an image in the batch respectively). The class labels themselves should + be a :obj:`torch.LongTensor` of len :obj:`(number of bounding boxes in the image,)` and the boxes a :obj:`torch.FloatTensor` + of shape :obj:`(number of bounding boxes in the image, 4)`. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # First, sent images through Backbone to obtain the features (includes features map, mask and position embeddings) + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + features, position_embeddings_list = self.backbone(samples) + + src, mask = features[-1].decompose() + assert mask is not None + + # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) + src = self.input_projection(src) + + # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC + # In other words, turn their shape into (batch_size, sequence_length, hidden_size) + batch_size, c, h, w = src.shape + src = src.flatten(2).permute(0, 2, 1) + position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1) + mask = ~mask.flatten(1) + + # Fourth, sent src + mask + position embeddings through encoder + # src is a Tensor of shape (batch_size, heigth*width, hidden_size) + # mask is a Tensor of shape (batch_size, heigth*width) + if encoder_outputs is None: + encoder_outputs = self.encoder( + inputs_embeds=src, + attention_mask=mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # Fifth, sent queries i.e. tgt (initialized with zeros), query position embeddings + position embeddings + # through the decoder (which is conditioned on the encoder output) + query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) + tgt = torch.zeros_like(query_position_embeddings) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + inputs_embeds=tgt, + attention_mask=None, + position_embeddings=position_embeddings, + query_position_embeddings=query_position_embeddings, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # class logits + predicted bounding boxes + # to do: make this as efficient as the original implementation + pred_logits = self.class_labels_classifier(decoder_outputs[0]) + pred_boxes = self.bbox_predictor(decoder_outputs[0]).sigmoid() + + loss, auxiliary_outputs = None, None + if labels is not None: + # First: create the matcher + matcher = HungarianMatcher(class_cost=self.config.class_cost, + bbox_cost=self.config.bbox_cost, + giou_cost=self.config.giou_cost) + # Second: create the criterion + weight_dict = {'loss_ce': 1, 'loss_bbox': self.config.bbox_loss_coefficient} + weight_dict['loss_giou'] = self.config.giou_loss_coefficient + # to do: move the following three lines to DetrForPanopticSegmentation + if self.config.masks: + weight_dict["loss_mask"] = self.config.mask_loss_coef + weight_dict["loss_dice"] = self.config.dice_loss_coef + # TODO this is a hack (doesn't work yet) + if self.config.auxiliary_loss: + aux_weight_dict = {} + for i in range(self.config.decoder_layers - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ['labels', 'boxes', 'cardinality'] + # to do: move the following two lines to DetrForPanopticSegmentation + if self.config.masks: + losses += ["masks"] + # (copied from original repo in detr.py): + # the naming of the `num_classes` parameter of the criterion is somewhat misleading. + # it indeed corresponds to `max_obj_id + 1`, where max_obj_id + # is the maximum id for a class in your dataset. For example, + # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. + # As another example, for a dataset that has a single class with id 1, + # you should pass `num_classes` to be 2 (max_obj_id + 1). + # For more details on this, check the following discussion + # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223 + criterion = SetCriterion(matcher=matcher, num_classes=self.config.num_labels, + weight_dict=weight_dict, eos_coef=self.config.eos_coefficient, losses=losses) + criterion.to(self.device) + # Third: compute the loss, based on outputs and labels + outputs = {} + outputs['pred_logits'] = pred_logits + outputs['pred_boxes'] = pred_boxes + if self.config.auxiliary_loss: + intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[5] + outputs_class = self.class_labels_classifier(intermediate) + outputs_coord = self.bbox_predictor(intermediate).sigmoid() + auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord) + outputs['auxiliary_outputs'] = auxiliary_outputs + + loss_dict = criterion(outputs, labels) + weight_dict = criterion.weight_dict + loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + + if not return_dict: + # to be verified + if auxiliary_outputs is not None: + output = (pred_logits, pred_boxes) + auxiliary_outputs + decoder_outputs + encoder_outputs + else: + output = (pred_logits, pred_boxes) + decoder_outputs + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return DetrObjectDetectionOutput( + loss=loss, + pred_logits=pred_logits, + pred_boxes=pred_boxes, + auxiliary_outputs=auxiliary_outputs, + #past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +# copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class SetCriterion(nn.Module): + """ This class computes the loss for DETRForObjectDetection. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, matcher, num_classes, weight_dict, eos_coef, losses): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals. + num_classes: number of object categories, omitting the special no-object category. + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category. + losses: list of all the losses to be applied. See get_loss for list of available losses. + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer('empty_weight', empty_weight) + + # (Niels): set log to False because we don't want to include accuracy in the modeling file + def loss_labels(self, outputs, targets, indices, num_boxes, log=False): + """Classification loss (NLL) + targets dicts must contain the key "class_labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(generalized_box_iou( + box_cxcywh_to_xyxy(src_boxes), + box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. + targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + src_masks = outputs["pred_masks"] + src_masks = src_masks[src_idx] + masks = [t["masks"] for t in targets] + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(src_masks) + target_masks = target_masks[tgt_idx] + + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(src_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'auxiliary_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["class_labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + # (Niels): comment out function below, distributed training to be added + # if is_dist_avail_and_initialized(): + # torch.distributed.all_reduce(num_boxes) + # (Niels) in original implementation, num_boxes is divided by get_world_size() + num_boxes = torch.clamp(num_boxes, min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'auxiliary_outputs' in outputs: + for i, auxiliary_outputs in enumerate(outputs['auxiliary_outputs']): + indices = self.matcher(auxiliary_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs = {'log': False} + l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + +# copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py +class MLP(nn.Module): + """ + Very simple multi-layer perceptron (also called FFN), used to predict the normalized + center coordinates, height and width of a bounding box w.r.t. an image. + + Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py + + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +# copied from https://github.com/facebookresearch/detr/blob/master/models/matcher.py +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1): + """Creates the matcher. + Params: + class_cost: This is the relative weight of the classification error in the matching cost + bbox_cost: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + giou_cost: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.class_cost = class_cost + self.bbox_cost = bbox_cost + self.giou_cost = giou_cost + assert class_cost != 0 or bbox_cost != 0 or giou_cost != 0, "All costs of the Matcher can't be 0" + + @torch.no_grad() + def forward(self, outputs, targets): + """ Performs the matching. + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs['pred_logits'].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs['pred_logits'].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + out_bbox = outputs['pred_boxes'].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["class_labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + class_cost = -out_prob[:, tgt_ids] + + # Compute the L1 cost between boxes + bbox_cost = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost between boxes + giou_cost = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + + # Final cost matrix + C = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +# below: functies copied from https://github.com/facebookresearch/detr/blob/master/util/box_ops.py + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + The boxes should be in [x0, y0, x1, y1] format + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area \ No newline at end of file diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py new file mode 100644 index 00000000000000..c1d16256d1fd9b --- /dev/null +++ b/src/transformers/models/detr/tokenization_detr.py @@ -0,0 +1,250 @@ +# coding=utf-8 +# Copyright Facebook AI Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DETR.""" +from typing import List, Optional + +from tokenizers import ByteLevelBPETokenizer + +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {} + +PRETRAINED_VOCAB_FILES_MAP = {} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "facebook/detr-resnet-50": 1024, +} + +class DetrTokenizer(PreTrainedTokenizer): + """ + Construct a DETR tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + **kwargs + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) + + "Initialisation" + + @property + def vocab_size(self): + "Returns vocab size" + + def get_vocab(self): + "Returns vocab as a dict" + + def _tokenize(self, text): + """ Returns a tokenized string. """ + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + + def save_vocabulary(self, save_directory): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (:obj:`str`): + The directory in which to save the vocabulary. + + Returns: + :obj:`Tuple(str)`: Paths to the files saved. + """ + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A DETR sequence has the following format: + + - single sequence: `` X `` + - pair of sequences: `` A B `` + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + DETR does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + +class DetrTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" DETR tokenizer (backed by HuggingFace's `tokenizers` library). + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + add_prefix_space=False, + trim_offsets=True, + **kwargs + ): + super().__init__( + ByteLevelBPETokenizer( + vocab_file=vocab_file, + merges_file=merges_file, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + ), + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + **kwargs, + ) + self.add_prefix_space = add_prefix_space + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + DETR does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + diff --git a/src/transformers/models/detr/tokenization_detr_fast.py b/src/transformers/models/detr/tokenization_detr_fast.py new file mode 100644 index 00000000000000..9ef37a723fbda4 --- /dev/null +++ b/src/transformers/models/detr/tokenization_detr_fast.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright Facebook AI Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DETR.""" +from typing import List, Optional + +from tokenizers import ByteLevelBPETokenizer + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_detr import DetrTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {} + +PRETRAINED_VOCAB_FILES_MAP = {} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "facebook/detr-resnet-50": 1024, +} + +class DetrTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" DETR tokenizer (backed by HuggingFace's `tokenizers` library). + + Args: + vocab_file (:obj:`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + slow_tokenizer_class = DetrTokenizer + + def __init__( + self, + vocab_file, + merges_file, + unk_token="<|endoftext|>", + bos_token="<|endoftext|>", + eos_token="<|endoftext|>", + add_prefix_space=False, + trim_offsets=True, + **kwargs + ): + super().__init__( + ByteLevelBPETokenizer( + vocab_file=vocab_file, + merges_file=merges_file, + add_prefix_space=add_prefix_space, + trim_offsets=trim_offsets, + ), + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + **kwargs, + ) + self.add_prefix_space = add_prefix_space + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + if token_ids_1 is None: + return output + + return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. + DETR does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + + diff --git a/tests/test_modeling_detr.py b/tests/test_modeling_detr.py new file mode 100644 index 00000000000000..a1e3e657f6cea9 --- /dev/null +++ b/tests/test_modeling_detr.py @@ -0,0 +1,371 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch DETR model. """ + + +import copy +import tempfile +import unittest + +from transformers import is_torch_available +from transformers.file_utils import cached_property +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_generation_utils import GenerationTesterMixin +from .test_modeling_common import ModelTesterMixin, ids_tensor + +from PIL import Image +import requests + + +if is_torch_available(): + import torch + import torchvision.transforms as T + + from transformers import ( + DetrConfig, + DetrModel, + DetrTokenizer, + DetrForObjectDetection, + ) + from transformers.models.detr.modeling_detr import ( + DetrDecoder, + DetrEncoder, + ) + + +def prepare_detr_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = input_ids.ne(config.pad_token_id) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) + return { + "input_ids": input_ids, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": attention_mask, + } + + +# @require_torch +# class DetrModelTester: +# def __init__( +# self, +# parent, +# batch_size=13, +# seq_length=7, +# is_training=True, +# use_labels=False, +# vocab_size=99, +# hidden_size=16, +# num_hidden_layers=2, +# num_attention_heads=4, +# intermediate_size=4, +# hidden_act="gelu", +# hidden_dropout_prob=0.1, +# attention_probs_dropout_prob=0.1, +# max_position_embeddings=20, +# eos_token_id=2, +# pad_token_id=1, +# bos_token_id=0, +# ): +# self.parent = parent +# self.batch_size = batch_size +# self.seq_length = seq_length +# self.is_training = is_training +# self.use_labels = use_labels +# self.vocab_size = vocab_size +# self.hidden_size = hidden_size +# self.num_hidden_layers = num_hidden_layers +# self.num_attention_heads = num_attention_heads +# self.intermediate_size = intermediate_size +# self.hidden_act = hidden_act +# self.hidden_dropout_prob = hidden_dropout_prob +# self.attention_probs_dropout_prob = attention_probs_dropout_prob +# self.max_position_embeddings = max_position_embeddings +# self.eos_token_id = eos_token_id +# self.pad_token_id = pad_token_id +# self.bos_token_id = bos_token_id + +# def prepare_config_and_inputs(self): +# input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) +# input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( +# 3, +# ) +# input_ids[:, -1] = self.eos_token_id # Eos Token + +# decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + +# config = DetrConfig( +# vocab_size=self.vocab_size, +# d_model=self.hidden_size, +# encoder_layers=self.num_hidden_layers, +# decoder_layers=self.num_hidden_layers, +# encoder_attention_heads=self.num_attention_heads, +# decoder_attention_heads=self.num_attention_heads, +# encoder_ffn_dim=self.intermediate_size, +# decoder_ffn_dim=self.intermediate_size, +# dropout=self.hidden_dropout_prob, +# attention_dropout=self.attention_probs_dropout_prob, +# max_position_embeddings=self.max_position_embeddings, +# eos_token_id=self.eos_token_id, +# bos_token_id=self.bos_token_id, +# pad_token_id=self.pad_token_id, +# ) +# inputs_dict = prepare_detr_inputs_dict(config, input_ids, decoder_input_ids) +# return config, inputs_dict + +# def prepare_config_and_inputs_for_common(self): +# config, inputs_dict = self.prepare_config_and_inputs() +# return config, inputs_dict + +# def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): +# model = DetrModel(config=config).get_decoder().to(torch_device).eval() +# input_ids = inputs_dict["input_ids"] +# attention_mask = inputs_dict["attention_mask"] + +# # first forward pass +# outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) + +# output, past_key_values = outputs.to_tuple() + +# # create hypothetical multiple next token and extent to next_input_ids +# next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) +# next_attn_mask = ids_tensor((self.batch_size, 3), 2) + +# # append to next input_ids and +# next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) +# next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1) + +# output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] +# output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)["last_hidden_state"] + +# # select random slice +# random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() +# output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() +# output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + +# self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + +# # test that outputs are equal for slice +# self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) + +# def check_encoder_decoder_model_standalone(self, config, inputs_dict): +# model = DetrModel(config=config).to(torch_device).eval() +# outputs = model(**inputs_dict) + +# encoder_last_hidden_state = outputs.encoder_last_hidden_state +# last_hidden_state = outputs.last_hidden_state + +# with tempfile.TemporaryDirectory() as tmpdirname: +# encoder = model.get_encoder() +# encoder.save_pretrained(tmpdirname) +# encoder = DetrEncoder.from_pretrained(tmpdirname).to(torch_device) + +# encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[ +# 0 +# ] + +# self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3) + +# with tempfile.TemporaryDirectory() as tmpdirname: +# decoder = model.get_decoder() +# decoder.save_pretrained(tmpdirname) +# decoder = DetrDecoder.from_pretrained(tmpdirname).to(torch_device) + +# last_hidden_state_2 = decoder( +# input_ids=inputs_dict["decoder_input_ids"], +# attention_mask=inputs_dict["decoder_attention_mask"], +# encoder_hidden_states=encoder_last_hidden_state, +# encoder_attention_mask=inputs_dict["attention_mask"], +# )[0] + +# self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3) + + +# @require_torch +# class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): +# all_model_classes = ( +# (DetrModel, DetrForObjectDetection,) +# if is_torch_available() +# else () +# ) +# #all_generative_model_classes = (DetrForConditionalGeneration,) if is_torch_available() else () +# is_encoder_decoder = True +# test_pruning = False +# test_head_masking = False +# test_missing_keys = False + +# def setUp(self): +# self.model_tester = DetrModelTester(self) +# self.config_tester = ConfigTester(self, config_class=DetrConfig) + +# def test_config(self): +# self.config_tester.run_common_tests() + +# def test_save_load_strict(self): +# config, inputs_dict = self.model_tester.prepare_config_and_inputs() +# for model_class in self.all_model_classes: +# model = model_class(config) + +# with tempfile.TemporaryDirectory() as tmpdirname: +# model.save_pretrained(tmpdirname) +# model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) +# self.assertEqual(info["missing_keys"], []) + +# def test_decoder_model_past_with_large_inputs(self): +# config_and_inputs = self.model_tester.prepare_config_and_inputs() +# self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + +# def test_encoder_decoder_model_standalone(self): +# config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() +# self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs) + +# # DetrForSequenceClassification does not support inputs_embeds +# def test_inputs_embeds(self): +# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + +# for model_class in (DetrModel, DetrForConditionalGeneration, DetrForQuestionAnswering): +# model = model_class(config) +# model.to(torch_device) +# model.eval() + +# inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + +# if not self.is_encoder_decoder: +# input_ids = inputs["input_ids"] +# del inputs["input_ids"] +# else: +# encoder_input_ids = inputs["input_ids"] +# decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) +# del inputs["input_ids"] +# inputs.pop("decoder_input_ids", None) + +# wte = model.get_input_embeddings() +# if not self.is_encoder_decoder: +# inputs["inputs_embeds"] = wte(input_ids) +# else: +# inputs["inputs_embeds"] = wte(encoder_input_ids) +# inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + +# with torch.no_grad(): +# model(**inputs)[0] + +# def test_generate_fp16(self): +# config, input_dict = self.model_tester.prepare_config_and_inputs() +# input_ids = input_dict["input_ids"] +# attention_mask = input_ids.ne(1).to(torch_device) +# model = DetrForConditionalGeneration(config).eval().to(torch_device) +# if torch_device == "cuda": +# model.half() +# model.generate(input_ids, attention_mask=attention_mask) +# model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + + +def assert_tensors_close(a, b, atol=1e-12, prefix=""): + """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if torch.allclose(a, b, atol=atol): + return True + raise + except Exception: + pct_different = (torch.gt((a - b).abs(), atol)).float().mean().item() + if a.numel() > 100: + msg = f"tensor values are {pct_different:.1%} percent different." + else: + msg = f"{a} != {b}" + if prefix: + msg = prefix + ": " + msg + raise AssertionError(msg) + + +def _long_tensor(tok_lst): + return torch.tensor(tok_lst, dtype=torch.long, device=torch_device) + + +TOLERANCE = 1e-4 + + +# We will verify our outputs against the original implementation on an image of cute cats +def prepare_img(): + url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + im = Image.open(requests.get(url, stream=True).raw) + + # standard PyTorch mean-std input image normalization + transform = T.Compose([ + T.Resize(800), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + + # mean-std normalize the input image (batch-size: 1) + img = transform(im).unsqueeze(0) + + return img + + +@require_torch +@slow +class DetrModelIntegrationTests(unittest.TestCase): + # @cached_property + # def default_tokenizer(self): + # return DetrTokenizer.from_pretrained('facebook/detr-resnet-50') + + def test_inference_no_head(self): + model = DetrModel.from_pretrained('nielsr/detr-resnet-50').to(torch_device) + model.eval() + img = prepare_img().to(torch_device) + + with torch.no_grad(): + outputs = model(img) + + expected_shape = torch.Size((1, 100, 256)) + assert outputs.last_hidden_state.shape == expected_shape + expected_slice = torch.tensor([[0.0616, -0.5146, -0.4032], + [-0.7629, -0.4934, -1.7153], + [-0.4768, -0.6403, -0.7826]]).to(torch_device) + self.assertTrue(torch.allclose(outputs.last_hidden_state[0,:3,:3], expected_slice, atol=1e-4)) + + + def test_inference_object_detection_head(self): + model = DetrForObjectDetection.from_pretrained('nielsr/detr-resnet-50').to(torch_device) + model.eval() + img = prepare_img().to(torch_device) + + with torch.no_grad(): + outputs = model(img) + + expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels + 1)) + self.assertEqual(outputs.pred_logits.shape, expected_shape_logits) + expected_slice_logits = torch.tensor([[-19.1194, -0.0893, -11.0154], + [-17.3640, -1.8035, -14.0219], + [-20.0461, -0.5837, -11.1060]]).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_logits[0,:3,:3], expected_slice_logits, atol=1e-4)) + + expected_shape_boxes = torch.Size((1, model.config.num_queries, 4)) + self.assertEqual(outputs.pred_boxes.shape, expected_shape_boxes) + expected_slice_boxes = torch.tensor([[0.4433, 0.5302, 0.8853], + [0.5494, 0.2517, 0.0529], + [0.4998, 0.5360, 0.9956]]).to(torch_device) + self.assertTrue(torch.allclose(outputs.pred_boxes[0,:3,:3], expected_slice_boxes, atol=1e-4)) \ No newline at end of file diff --git a/utils/check_repo.py b/utils/check_repo.py index c8881baa651d9b..9b6e574bdbe144 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -30,6 +30,8 @@ # Being in this list is an exception and should **not** be the rule. IGNORE_NON_TESTED = [ # models to ignore for not tested +"DetrEncoder", # Building part of bigger (tested) model. + "DetrDecoder", # Building part of bigger (tested) model. "LEDEncoder", # Building part of bigger (tested) model. "LEDDecoder", # Building part of bigger (tested) model. "BartDecoderWrapper", # Building part of bigger (tested) model. @@ -75,6 +77,8 @@ # should **not** be the rule. IGNORE_NON_AUTO_CONFIGURED = [ # models to ignore for model xxx mapping +"DetrEncoder", + "DetrDecoder", "LEDEncoder", "LEDDecoder", "BartDecoder", From 75936a41ee8dd2204098f1a0b9100eb139f52e26 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Fri, 5 Feb 2021 18:02:40 +0100 Subject: [PATCH 02/20] First draft of DetrTokenizer --- .gitattributes..txt | 3 + docs/source/model_doc/detr.rst | 82 ++- .../models/detr/configuration_detr.py | 3 +- src/transformers/models/detr/modeling_detr.py | 14 +- .../models/detr/tokenization_detr.py | 589 ++++++++++++------ tests/test_modeling_detr.py | 10 - tests/test_tokenization_detr.py | 57 ++ 7 files changed, 522 insertions(+), 236 deletions(-) create mode 100644 .gitattributes..txt create mode 100644 tests/test_tokenization_detr.py diff --git a/.gitattributes..txt b/.gitattributes..txt new file mode 100644 index 00000000000000..800966870fb8cc --- /dev/null +++ b/.gitattributes..txt @@ -0,0 +1,3 @@ +*.py eol=lf +*.rst eol=lf +*.md eol=lf \ No newline at end of file diff --git a/docs/source/model_doc/detr.rst b/docs/source/model_doc/detr.rst index 2fe08238844822..9bcb9356b0f3f0 100644 --- a/docs/source/model_doc/detr.rst +++ b/docs/source/model_doc/detr.rst @@ -16,16 +16,62 @@ DETR Overview ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The DETR model was proposed in ` -<>`__ by . +The DETR model was proposed in `End-to-End Object Detection with Transformers +`__ by Nicolas Carion, Francisco Massa, Gabriel Synnaeve, Nicolas Usunier, Alexander Kirillov +and Sergey Zagoruyko. DETR consists of a convolutional backbone followed by an encoder-decoder Transformer which can be trained +end-to-end for object detection. It greatly simplifies a lot of the complexity of models like Faster-R-CNN and Mask-R-CNN, which +use things like region proposals, non-maximum suppression procedure and anchor generation. Moreover, DETR can also be naturally extended +to perform panoptic segmentation, by simply adding a mask head on top of the decoder outputs. The abstract from the paper is the following: -** +*We present a new method that views object detection as a direct set prediction problem. Our approach streamlines the detection pipeline, +effectively removing the need for many hand-designed components like a non-maximum suppression procedure or anchor generation that explicitly +encode our prior knowledge about the task. The main ingredients of the new framework, called DEtection TRansformer or DETR, are a set-based +global loss that forces unique predictions via bipartite matching, and a transformer encoder-decoder architecture. Given a fixed small set of +learned object queries, DETR reasons about the relations of the objects and the global image context to directly output the final set of predictions +in parallel. The new model is conceptually simple and does not require a specialized library, unlike many other modern detectors. DETR demonstrates +accuracy and run-time performance on par with the well-established and highly-optimized Faster RCNN baseline on the challenging COCO object detection +dataset. Moreover, DETR can be easily generalized to produce panoptic segmentation in a unified manner. We show that it significantly outperforms +competitive baselines.* + +The original code can be found `here `__. + +Here's a TLDR explaining how the model works: + +First, an image is sent through a pre-trained convolutional backbone (in the paper, the authors use ResNet-50/ResNet-101). Let's assume we also add a +batch dimension. This means that the input to the backbone is a tensor of shape :obj:`(1, 3, height, width)`, assuming the image has 3 color channels (RGB). +The CNN backbone outputs a new lower-resolution feature map, typically of shape :obj:`(1, 2048, height/32, width/32)`. This is then projected to match +the hidden dimension of the Transformer of DETR, which is :obj:`256` by default, using a :obj:`nn.Conv2D` layer. So now, we have a tensor of shape +:obj:`(1, 256, height/32, width/32).` Next, the image is flattened and transposed to obtain a tensor of shape :obj:`(batch_size, seq_len, d_model)` = +:obj:`(1, width/32*height/32, 256)`. So a difference with NLP models is that the sequence length is actually longer than usual, but with a smaller +:obj:`d_model` (which in NLP is typically 768 or higher). + +Next, this is sent through the encoder, outputting :obj:`encoder_hidden_states` of the same shape (you can consider these as image features). Next, so-called +**object queries** are sent through the decoder. This is a tensor of shape :obj:`(batch_size, num_queries, d_model)`, with `num_queries` typically set +to 100 and is initialized with zeros. Each object query looks for a particular object in the image. Next, the decoder updates these object queries through +multiple self-attention and encoder-decoder attention layers to output :obj:`decoder_hidden_states` of the same shape: :obj:`(batch_size, num_queries, d_model)`. +Next, two heads are added on top for object detection: a linear layer for classifying each object query into one of the objects or "no object", and a MLP +to predict bounding boxes for each query. + +The model is trained using a **bipartite matching loss**: so what we actually do is compare the predicted classes + bounding boxes of each of the +N = 100 object queries to the ground truth annotations, padded up to the same length N (so if an image only contains 4 objects, 96 annotations will +just have a "no object" as class and "no bounding box" as bounding box). The `Hungarian matching algorithm `__ is used to create a one-to-one mapping of +each of the N queries to each of the N annotations. Next, standard cross-entropy and L1 bounding box losses are used to optimize the parameters of +the model. Tips: - +- DETR uses so-called **object queries** to detect objects in an image. The number of queries determines the maximum number of objects that + can be detected in a single image, and is set to 100 by default (see parameter :obj:`num_queries` of :class:`~transformers.DetrConfig`). +- The decoder of DETR updates the query embeddings in parallel. This is different from language models like GPT-2, which use autoregressive decoding + instead of parallel. Hence, no causal attention mask is used. +- DETR adds position embeddings to the hidden states at each self-attention and cross-attention layer before projecting to queries and keys. + For the position embeddings of the image, one can choose between fixed sinusoidal or learned absolute position embeddings. By default, + the parameter :obj:`position_embedding_type` of :class:`~transformers.DetrConfig` is set to :obj:`sine`. +- During training, the authors of DETR did find it helpful to use auxiliary losses in the decoder, especially to help the model output the correct + number of objects of each class. If you set the parameter :obj:`auxiliary_loss` of :class:`~transformers.DetrConfig` to :obj:`True`, then prediction + feedforward neural networks and Hungarian losses are added after each decoder layer (with the FFNs sharing parameters). DetrConfig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -38,16 +84,7 @@ DetrTokenizer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.DetrTokenizer - :members: build_inputs_with_special_tokens, get_special_tokens_mask, - create_token_type_ids_from_sequences, save_vocabulary - - -DetrTokenizerFast -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: transformers.DetrTokenizerFast - :members: build_inputs_with_special_tokens, get_special_tokens_mask, - create_token_type_ids_from_sequences, save_vocabulary + :members: __call__ DetrModel @@ -57,24 +94,11 @@ DetrModel :members: forward -DetrForConditionalGeneration +DetrForObjectDetection ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.DetrForConditionalGeneration +.. autoclass:: transformers.DetrForObjectDetection :members: forward -DetrForSequenceClassification -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: transformers.DetrForSequenceClassification - :members: forward - - -DetrForQuestionAnswering -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: transformers.DetrForQuestionAnswering - :members: forward - diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py index 6b9c76a476c760..be50649ca601e4 100644 --- a/src/transformers/models/detr/configuration_detr.py +++ b/src/transformers/models/detr/configuration_detr.py @@ -108,7 +108,8 @@ class DetrConfig(PretrainedConfig): Relative weight of the generalized IoU loss in the object detection loss. eos_coefficient (:obj:`float`, `optional`, defaults to 0.1): Relative classification weight of the 'no-object' class in the object detection loss. - Example:: + + Examples:: >>> from transformers import DetrModel, DetrConfig diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index ac5c7b19785b6c..f97455a7a89fd8 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -18,7 +18,7 @@ import math import random from dataclasses import dataclass -from typing import Optional, Tuple, Dict, List +from typing import Optional, Tuple, Dict, List, Union from scipy.optimize import linear_sum_assignment import torch @@ -186,6 +186,7 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): # TODO make it support different-sized images max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + print('Max size', max_size) # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) batch_shape = [len(tensor_list)] + max_size b, c, h, w = batch_shape @@ -289,6 +290,17 @@ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) self.num_channels = num_channels + # def forward_new(self, pixel_values: Union[torch.Tensor, list[torch.Tensor]], pixel_mask: Optional[torch.Tensor]): + # xs = self.body(pixel_values) + # out: Dict[str, NestedTensor] = {} + # for name, x in xs.items(): + # m = pixel_mask + # assert m is not None + # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + # out[name] = NestedTensor(x, mask) + # return out + + # this one should be removed in the future def forward(self, tensor_list: NestedTensor): xs = self.body(tensor_list.tensors) out: Dict[str, NestedTensor] = {} diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index c1d16256d1fd9b..cbab0d7b1e62d5 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright Facebook AI Research and The HuggingFace Inc. team. All rights reserved. +# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,239 +12,438 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for DETR.""" -from typing import List, Optional - -from tokenizers import ByteLevelBPETokenizer - -from ...tokenization_utils import AddedToken, PreTrainedTokenizer -from ...tokenization_utils_fast import PreTrainedTokenizerFast +"""Tokenization class for DETR.""" + +import json +import os +from itertools import groupby +from typing import Dict, List, Optional, Tuple, Union +import random + +import numpy as np +import torch +from torch import Tensor +import torchvision +import torchvision.transforms.functional as F +from torchvision import transforms as T +import PIL + +from ...file_utils import add_end_docstrings +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TensorType from ...utils import logging logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = {} - -PRETRAINED_VOCAB_FILES_MAP = {} - -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - "facebook/detr-resnet-50": 1024, -} -class DetrTokenizer(PreTrainedTokenizer): +DETR_KWARGS_DOCSTRING = r""" + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to print more information and warnings. +""" + +## BELOW: utilities copied from +# https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/util/misc.py + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): """ - Construct a DETR tokenizer. Based on byte-level Byte-Pair-Encoding. - - Args: - vocab_file (:obj:`str`): - Path to the vocabulary file. + Data type that handles different types of inputs (either list of images or list of sequences), + and computes the padded output (with masking). """ + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: Union[List[Tensor], torch.Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('Not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +## Image + target transformations for object detection +## Taken from https://github.com/facebookresearch/detr/blob/master/datasets/transforms.py + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + + return rescaled_image, target + + +class RandomResize(object): + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask"] - def __init__( - self, - vocab_file, - unk_token="<|endoftext|>", - bos_token="<|endoftext|>", - eos_token="<|endoftext|>", - **kwargs - ): - bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token - super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) - - "Initialisation" +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std - @property - def vocab_size(self): - "Returns vocab size" + def __call__(self, image, target=None): + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xyxy_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target - def get_vocab(self): - "Returns vocab as a dict" - def _tokenize(self, text): - """ Returns a tokenized string. """ +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms - def _convert_token_to_id(self, token): - """ Converts a token (str) in an id using the vocab. """ + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - def convert_tokens_to_string(self, tokens): - """ Converts a sequence of tokens (string) in a single string. """ +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target - def save_vocabulary(self, save_directory): - """ - Save the vocabulary and special tokens file to a directory. - Args: - save_directory (:obj:`str`): - The directory in which to save the vocabulary. - - Returns: - :obj:`Tuple(str)`: Paths to the files saved. - """ - - def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Build model inputs from a sequence or a pair of sequence for sequence classification tasks - by concatenating and adding special tokens. - A DETR sequence has the following format: - - - single sequence: `` X `` - - pair of sequences: `` A B `` - - Args: - token_ids_0 (:obj:`List[int]`): - List of IDs to which the special tokens will be added. - token_ids_1 (:obj:`List[int]`, `optional`): - Optional second list of IDs for sequence pairs. - - Returns: - :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. - """ - if token_ids_1 is None: - return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] - cls = [self.cls_token_id] - sep = [self.sep_token_id] - return cls + token_ids_0 + sep + sep + token_ids_1 + sep - - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer ``prepare_for_model`` method. - - Args: - token_ids_0 (:obj:`List[int]`): - List of IDs. - token_ids_1 (:obj:`List[int]`, `optional`): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - if already_has_special_tokens: - if token_ids_1 is not None: - raise ValueError( - "You should not supply a second sequence if the provided sequence of " - "ids is already formatted with special tokens for the model." - ) - return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) - - if token_ids_1 is None: - return [1] + ([0] * len(token_ids_0)) + [1] - return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] - - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Create a mask from the two sequences passed to be used in a sequence-pair classification task. - DETR does not make use of token type ids, therefore a list of zeros is returned. - - Args: - token_ids_0 (:obj:`List[int]`): - List of IDs. - token_ids_1 (:obj:`List[int]`, `optional`): - Optional second list of IDs for sequence pairs. - - Returns: - :obj:`List[int]`: List of zeros. - """ - sep = [self.sep_token_id] - cls = [self.cls_token_id] - - if token_ids_1 is None: - return len(cls + token_ids_0 + sep) * [0] - return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] - - def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): - add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) - if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): - text = " " + text - return (text, kwargs) - -class DetrTokenizerFast(PreTrainedTokenizerFast): +class DetrTokenizer(PreTrainedTokenizer): """ - Construct a "fast" DETR tokenizer (backed by HuggingFace's `tokenizers` library). + Constructs a DETR tokenizer. + + This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains some of the main methods. + Users should refer to the superclass for more information regarding such methods. Args: - vocab_file (:obj:`str`): - Path to the vocabulary file. + bos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The beginning of sentence token. + eos_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The end of sentence token. + unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (:obj:`str`, `optional`, defaults to :obj:`""`): + The token used for padding, for example when batching sequences of different lengths. + word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`): + The token used for defining the end of a word. + **kwargs + Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer` """ - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask"] + model_input_names = ["input_values"] def __init__( - self, - vocab_file, - merges_file, - unk_token="<|endoftext|>", - bos_token="<|endoftext|>", - eos_token="<|endoftext|>", - add_prefix_space=False, - trim_offsets=True, - **kwargs + self, + bos_token="", + eos_token="", + unk_token="", + pad_token="", + word_delimiter_token="|", + do_lower_case=False, + **kwargs ): super().__init__( - ByteLevelBPETokenizer( - vocab_file=vocab_file, - merges_file=merges_file, - add_prefix_space=add_prefix_space, - trim_offsets=trim_offsets, - ), + unk_token=unk_token, bos_token=bos_token, eos_token=eos_token, - unk_token=unk_token, + pad_token=pad_token, + do_lower_case=do_lower_case, + word_delimiter_token=word_delimiter_token, **kwargs, ) - self.add_prefix_space = add_prefix_space + self._word_delimiter_token = word_delimiter_token + self.do_lower_case = do_lower_case + + @add_end_docstrings(DETR_KWARGS_DOCSTRING) + def __call__( + self, + images: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor]], + annotations: Optional[Union[Dict, List[Dict]]] = None, + padding: Union[bool, str] = True, + return_mask: Union[bool, str] = True, + resize: Optional[bool] = True, + size: Optional[int] = 800, + max_size: Optional[int] = 1333, + normalize: Optional[bool] = True, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Main method to prepare for the model one or several image(s). This includes resizing, normalization and padding up to + the largest image in a batch while creating a pixel mask for each image indicating which pixels are real and which are padding. + + Args: + images (:obj:`PIL.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`, + :obj:`List[PIL.Image]`): + The image or batch of images to be prepared. Each image can be a PIL image, numpy array or a Torch tensor. + annotations (:obj:`Dict`, :obj:`List[Dict]`): + The annotations as either a single Python dictionary or a batch of Python dictionaries. Keys include "size", "area", + "boxes" and "masks". + resize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to resize images to a certain :obj:`size`. + size (:obj:`int`, `optional`, defaults to :obj:`800`): + Resize the input image to the given size. Only has an effect if :obj:`resize` is set to :obj:`True`. + max_size (:obj:`int`, `optional`, defaults to :obj:`1333`): + Resize up to a certain max size. In COCO, the authors of DETR used a :obj:`max_size` of 1333. + normalize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to apply standard ImageNet mean/std normalization of images. + """ - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - output = [self.bos_token_id] + token_ids_0 + [self.eos_token_id] - if token_ids_1 is None: - return output + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (PIL.Image.Image, np.ndarray, torch.Tensor))) + ) - return output + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] + # make images a list of PIL images no matter what + if is_batched: + if isinstance(images[0], np.array): + images = [Image.fromarray(image) for image in images] + if annotations is not None: + assert len(images) == len(annotations) + + elif isinstance(images[0], torch.Tensor): + images = [T.ToPILImage()(image).convert("RGB") for image in images] + if annotations is not None: + assert len(images) == len(annotations) + else: + if isinstance(images, PIL.Image.Image): + images = [images] + annotations = [annotations] + + # next up: apply image transformations (resizing + normalization) + transformations = [] + if resize and size is not None: + transformations.append(RandomResize(sizes=[size], max_size=max_size)) + if normalize: + normalize = Compose([ + ToTensor(), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) + transformations.append(normalize) + transforms = Compose(transformations) + transformed_images = [] + for image, annotation in zip(images, annotations): + image, annotation = transforms(image, annotation) + transformed_images.append(image) + + # next, create NestedTensor which takes care of padding up to biggest image + samples = nested_tensor_from_tensor_list(transformed_images) + + # return as dict + data = {"pixel_values": samples.tensors, 'pixel_mask': samples.mask} + encoded_inputs = BatchEncoding(data=data) + + return encoded_inputs + + @property + def vocab_size(self) -> int: + return len(self.decoder) + def get_vocab(self) -> Dict: + return dict(self.encoder, **self.added_tokens_encoder) - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """ + Converts a connectionist-temporal-classification (CTC) output tokens into a single string. """ - Create a mask from the two sequences passed to be used in a sequence-pair classification task. - DETR does not make use of token type ids, therefore a list of zeros is returned. + # group same tokens into non-repeating tokens in CTC style decoding + grouped_tokens = [token_group[0] for token_group in groupby(tokens)] - Args: - token_ids_0 (:obj:`List[int]`): - List of IDs. - token_ids_1 (:obj:`List[int]`, `optional`): - Optional second list of IDs for sequence pairs. + # filter self.pad_token which is used as CTC-blank token + filtered_tokens = list(filter(lambda token: token != self.pad_token, grouped_tokens)) - Returns: - :obj:`List[int]`: List of zeros. + # replace delimiter token + string = "".join([" " if token == self.word_delimiter_token else token for token in filtered_tokens]).strip() + + if self.do_lower_case: + string = string.lower() + return string + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: bool = True, + ) -> str: + """ + special _decode function is needed for DETRTokenizer because added tokens should be treated exactly the + same as tokens of the base vocabulary and therefore the function `convert_tokens_to_string` has to be called on + the whole token list and not individually on added tokens """ - sep = [self.sep_token_id] - cls = [self.cls_token_id] + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) - if token_ids_1 is None: - return len(cls + token_ids_0 + sep) * [0] - return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + result = [] + for token in filtered_tokens: + if skip_special_tokens and token in self.all_special_ids: + continue + result.append(token) + text = self.convert_tokens_to_string(result) + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text \ No newline at end of file diff --git a/tests/test_modeling_detr.py b/tests/test_modeling_detr.py index a1e3e657f6cea9..563d703d7a572e 100644 --- a/tests/test_modeling_detr.py +++ b/tests/test_modeling_detr.py @@ -270,16 +270,6 @@ def prepare_detr_inputs_dict( # with torch.no_grad(): # model(**inputs)[0] -# def test_generate_fp16(self): -# config, input_dict = self.model_tester.prepare_config_and_inputs() -# input_ids = input_dict["input_ids"] -# attention_mask = input_ids.ne(1).to(torch_device) -# model = DetrForConditionalGeneration(config).eval().to(torch_device) -# if torch_device == "cuda": -# model.half() -# model.generate(input_ids, attention_mask=attention_mask) -# model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) - def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/test_tokenization_detr.py b/tests/test_tokenization_detr.py new file mode 100644 index 00000000000000..3b8d48a23160a8 --- /dev/null +++ b/tests/test_tokenization_detr.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the Wav2Vec2 tokenizer.""" + +import inspect +import json +import os +import random +import shutil +import tempfile +import unittest + +import torch +import numpy as np + +from PIL import Image +import requests + +from transformers.models.detr.tokenization_detr import DetrTokenizer + +class DetrTokenizerTest(unittest.TestCase): + tokenizer_class = DetrTokenizer + + def setUp(self): + super().setUp() + + url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + self.img = Image.open(requests.get(url, stream=True).raw) + + def get_tokenizer(self, **kwars): + return DetrTokenizer() + + def test_tokenizer_no_resize(self): + tokenizer = self.get_tokenizer() + encoding = tokenizer(self.img, resize=False) + + self.assertEqual(encoding["pixel_values"].shape, (1,3,480,640)) + self.assertEqual(encoding["pixel_mask"].shape, (1,480,640)) + + def test_tokenizer(self): + tokenizer = self.get_tokenizer() + encoding = tokenizer(self.img) + + self.assertEqual(encoding["pixel_values"].shape, (1,3,800,1066)) + self.assertEqual(encoding["pixel_mask"].shape, (1,800,1066)) \ No newline at end of file From e41808a5027bbcc270207c3f1f47daf5b19e1b54 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Fri, 5 Feb 2021 18:06:52 +0100 Subject: [PATCH 03/20] Update gitattributes file --- .gitattributes.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .gitattributes.txt diff --git a/.gitattributes.txt b/.gitattributes.txt new file mode 100644 index 00000000000000..800966870fb8cc --- /dev/null +++ b/.gitattributes.txt @@ -0,0 +1,3 @@ +*.py eol=lf +*.rst eol=lf +*.md eol=lf \ No newline at end of file From 7a5617100a8efe7174300f2945b0ad74264104ed Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Fri, 5 Feb 2021 18:11:23 +0100 Subject: [PATCH 04/20] Introduce end-of-line normalization --- .gitattributes..txt | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .gitattributes..txt diff --git a/.gitattributes..txt b/.gitattributes..txt deleted file mode 100644 index 800966870fb8cc..00000000000000 --- a/.gitattributes..txt +++ /dev/null @@ -1,3 +0,0 @@ -*.py eol=lf -*.rst eol=lf -*.md eol=lf \ No newline at end of file From 5f50e09b7966553136f7e5f5b88f6cd2462fc925 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Fri, 5 Feb 2021 18:18:54 +0100 Subject: [PATCH 05/20] Another try to solve the line endings issue --- .gitattributes.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.gitattributes.txt b/.gitattributes.txt index 800966870fb8cc..2125666142eb66 100644 --- a/.gitattributes.txt +++ b/.gitattributes.txt @@ -1,3 +1 @@ -*.py eol=lf -*.rst eol=lf -*.md eol=lf \ No newline at end of file +* text=auto \ No newline at end of file From 62e12ce5c0fab66fd79afb339f01e7dde5ef4215 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Fri, 5 Feb 2021 18:22:23 +0100 Subject: [PATCH 06/20] Added more lines to the gitattributes file --- .gitattributes.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitattributes.txt b/.gitattributes.txt index 2125666142eb66..06f7128826c165 100644 --- a/.gitattributes.txt +++ b/.gitattributes.txt @@ -1 +1,6 @@ -* text=auto \ No newline at end of file +# Set the default behavior, in case people don't have core.autocrlf set +* text=auto + +*.py eol=lf +*.rst eol=lf +*.md eol=lf \ No newline at end of file From fa0a155ccc93cc0fe5cc936ab8014737b527c296 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Fri, 5 Feb 2021 19:33:21 +0100 Subject: [PATCH 07/20] Improve docs --- docs/source/model_doc/detr.rst | 5 +++-- src/transformers/models/detr/tokenization_detr.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/model_doc/detr.rst b/docs/source/model_doc/detr.rst index 9bcb9356b0f3f0..d23af0df4a3578 100644 --- a/docs/source/model_doc/detr.rst +++ b/docs/source/model_doc/detr.rst @@ -48,7 +48,7 @@ the hidden dimension of the Transformer of DETR, which is :obj:`256` by default, :obj:`d_model` (which in NLP is typically 768 or higher). Next, this is sent through the encoder, outputting :obj:`encoder_hidden_states` of the same shape (you can consider these as image features). Next, so-called -**object queries** are sent through the decoder. This is a tensor of shape :obj:`(batch_size, num_queries, d_model)`, with `num_queries` typically set +**object queries** are sent through the decoder. This is a tensor of shape :obj:`(batch_size, num_queries, d_model)`, with :obj:`num_queries` typically set to 100 and is initialized with zeros. Each object query looks for a particular object in the image. Next, the decoder updates these object queries through multiple self-attention and encoder-decoder attention layers to output :obj:`decoder_hidden_states` of the same shape: :obj:`(batch_size, num_queries, d_model)`. Next, two heads are added on top for object detection: a linear layer for classifying each object query into one of the objects or "no object", and a MLP @@ -64,11 +64,12 @@ Tips: - DETR uses so-called **object queries** to detect objects in an image. The number of queries determines the maximum number of objects that can be detected in a single image, and is set to 100 by default (see parameter :obj:`num_queries` of :class:`~transformers.DetrConfig`). + Note that it's good to have some slack (in COCO, the authors used 100, while the maximum number of objects in a COCO image is ~70). - The decoder of DETR updates the query embeddings in parallel. This is different from language models like GPT-2, which use autoregressive decoding instead of parallel. Hence, no causal attention mask is used. - DETR adds position embeddings to the hidden states at each self-attention and cross-attention layer before projecting to queries and keys. For the position embeddings of the image, one can choose between fixed sinusoidal or learned absolute position embeddings. By default, - the parameter :obj:`position_embedding_type` of :class:`~transformers.DetrConfig` is set to :obj:`sine`. + the parameter :obj:`position_embedding_type` of :class:`~transformers.DetrConfig` is set to :obj:`"sine"`. - During training, the authors of DETR did find it helpful to use auxiliary losses in the decoder, especially to help the model output the correct number of objects of each class. If you set the parameter :obj:`auxiliary_loss` of :class:`~transformers.DetrConfig` to :obj:`True`, then prediction feedforward neural networks and Hungarian losses are added after each decoder layer (with the FFNs sharing parameters). diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index cbab0d7b1e62d5..08de40155340e5 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -346,7 +346,7 @@ def __call__( size (:obj:`int`, `optional`, defaults to :obj:`800`): Resize the input image to the given size. Only has an effect if :obj:`resize` is set to :obj:`True`. max_size (:obj:`int`, `optional`, defaults to :obj:`1333`): - Resize up to a certain max size. In COCO, the authors of DETR used a :obj:`max_size` of 1333. + The largest size an image dimension can have (otherwise it's capped). In COCO, the authors of DETR used a :obj:`max_size` of 1333. normalize (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to apply standard ImageNet mean/std normalization of images. """ From fd6dd758a891ade64c01fe8e04db148d9ad62eab Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Sat, 6 Feb 2021 21:39:45 +0100 Subject: [PATCH 08/20] More improvements to DetrTokenizer --- .../models/detr/tokenization_detr.py | 127 ++++++++++++++---- tests/test_tokenization_detr.py | 92 ++++++++++++- 2 files changed, 194 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index 08de40155340e5..f79abd6f1932e1 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -107,7 +107,8 @@ def __repr__(self): def nested_tensor_from_tensor_list(tensor_list: Union[List[Tensor], torch.Tensor]): - # TODO make this more general + # TODO make this more n + print(tensor_list[0].shape) if tensor_list[0].ndim == 3: if torchvision._is_tracing(): # nested_tensor_from_tensor_list() does not export well to ONNX @@ -165,6 +166,75 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen ## Image + target transformations for object detection ## Taken from https://github.com/facebookresearch/detr/blob/master/datasets/transforms.py +# this: extra transform, based on https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/datasets/coco.py#L21 +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False): + self.return_masks = return_masks + + def __call__(self, image, target): + w, h = image.size + + #image_id = target["image_id"] + #image_id = torch.tensor([image_id]) + + #anno = target["annotations"] + + #anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] + + #boxes = [obj["bbox"] for obj in anno] + if target is not None: + boxes = target["boxes"] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + #classes = [obj["category_id"] for obj in anno] + #classes = torch.tensor(classes, dtype=torch.int64) + + if self.return_masks: + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + # keypoints = None + # if anno and "keypoints" in anno[0]: + # keypoints = [obj["keypoints"] for obj in anno] + # keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + # num_keypoints = keypoints.shape[0] + # if num_keypoints: + # keypoints = keypoints.view(num_keypoints, -1, 3) + + # keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + # boxes = boxes[keep] + # classes = classes[keep] + # if self.return_masks: + # masks = masks[keep] + # if keypoints is not None: + # keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + # target["labels"] = classes + # if self.return_masks: + # target["masks"] = masks + # target["image_id"] = image_id + # if keypoints is not None: + # target["keypoints"] = keypoints + + # # for conversion to coco api + # area = torch.tensor([obj["area"] for obj in anno]) + # iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + # target["area"] = area[keep] + # target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(h), int(w)]) + target["size"] = torch.as_tensor([int(h), int(w)]) + + return image, target + return image, None + + def resize(image, target, size, max_size=None): # size can be min_size (scalar) or (w, h) tuple @@ -224,15 +294,20 @@ def get_size(image_size, size, max_size=None): return rescaled_image, target -class RandomResize(object): - def __init__(self, sizes, max_size=None): - assert isinstance(sizes, (list, tuple)) - self.sizes = sizes +class Resize(object): + def __init__(self, size, max_size=None): + self.size = size self.max_size = max_size def __call__(self, img, target=None): - size = random.choice(self.sizes) - return resize(img, target, size, self.max_size) + return resize(img, target, self.size, self.max_size) + +# copied from https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/util/box_ops.py +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) class Normalize(object): @@ -258,7 +333,7 @@ class Compose(object): def __init__(self, transforms): self.transforms = transforms - def __call__(self, image, target): + def __call__(self, image, target=None): for t in self.transforms: image, target = t(image, target) return image, target @@ -358,24 +433,22 @@ def __call__( # make images a list of PIL images no matter what if is_batched: - if isinstance(images[0], np.array): + if isinstance(images[0], np.ndarray): images = [Image.fromarray(image) for image in images] - if annotations is not None: - assert len(images) == len(annotations) - elif isinstance(images[0], torch.Tensor): images = [T.ToPILImage()(image).convert("RGB") for image in images] - if annotations is not None: - assert len(images) == len(annotations) + if annotations is not None: + assert len(images) == len(annotations) else: if isinstance(images, PIL.Image.Image): images = [images] + if annotations is not None: annotations = [annotations] - # next up: apply image transformations (resizing + normalization) - transformations = [] + # next: apply transformations (resizing + normalization) to both images and annotations + transformations = [ConvertCocoPolysToMask(),] if resize and size is not None: - transformations.append(RandomResize(sizes=[size], max_size=max_size)) + transformations.append(Resize(size=size, max_size=max_size)) if normalize: normalize = Compose([ ToTensor(), @@ -383,16 +456,22 @@ def __call__( ]) transformations.append(normalize) transforms = Compose(transformations) - transformed_images = [] - for image, annotation in zip(images, annotations): - image, annotation = transforms(image, annotation) - transformed_images.append(image) - # next, create NestedTensor which takes care of padding up to biggest image - samples = nested_tensor_from_tensor_list(transformed_images) + if annotations is not None: + transformed_images = [] + for image, annotation in zip(images, annotations): + image, annotation = transforms(image, annotation) + transformed_images.append(image) + else: + transformed_images = [transforms(image, None) for image in images] + + # next, create NestedTensor which takes care of padding the pixels up to biggest image + # we don't need the transformed targets for this + samples = nested_tensor_from_tensor_list([x[0] for x in transformed_images]) # return as dict - data = {"pixel_values": samples.tensors, 'pixel_mask': samples.mask} + data = {'pixel_values': samples.tensors, 'pixel_mask': samples.mask, + 'labels': [x[1] for x in transformed_images] if annotations is not None else None} encoded_inputs = BatchEncoding(data=data) return encoded_inputs diff --git a/tests/test_tokenization_detr.py b/tests/test_tokenization_detr.py index 3b8d48a23160a8..b30b0ab4cf56a6 100644 --- a/tests/test_tokenization_detr.py +++ b/tests/test_tokenization_detr.py @@ -36,12 +36,83 @@ class DetrTokenizerTest(unittest.TestCase): def setUp(self): super().setUp() + # single PIL image url = 'http://images.cocodataset.org/val2017/000000039769.jpg' self.img = Image.open(requests.get(url, stream=True).raw) + # batch of PIL images + annotations + base_url = "http://images.cocodataset.org/val2017/" + image_urls = ["000000087038.jpg", "000000578500.jpg", "000000261982.jpg"] + + images = [] + for image_url in image_urls: + images.append(Image.open(requests.get(base_url + image_url, stream=True).raw)) + self.images = images + + # each target is a dict with keys "boxes" and "area" + self.annotations = [{'boxes': [[253.21, 271.07, 59.59, 60.97], + [226.04, 229.31, 11.59, 30.41], + [257.85, 224.48, 44.13, 97.0], + [68.18, 238.19, 16.18, 42.88], + [79.16, 232.26, 28.22, 51.12], + [98.4, 234.28, 19.52, 46.46], + [326.86, 223.46, 13.11, 38.67], + [155.27, 246.34, 14.87, 21.99], + [298.61, 316.85, 63.91, 47.19], + [345.41, 173.41, 72.94, 185.41], + [239.72, 225.38, 10.64, 33.06], + [167.02, 234.0, 15.78, 37.46], + [209.68, 231.08, 9.15, 34.53], + [408.29, 231.25, 17.12, 34.97], + [204.14, 229.02, 7.33, 34.96], + [195.32, 228.06, 10.65, 37.18], + [1, 190, 638, 101]], + 'area': [1391.4269500000005, + 232.4970999999999, + 1683.128300000001, + 413.4482999999996, + 563.41615, + 363.24569999999994, + 261.10905000000054, + 152.3124499999996, + 1268.7499999999989, + 4686.905750000002, + 204.17735000000013, + 277.0192999999997, + 241.29070000000024, + 243.31384999999952, + 188.82489999999987, + 294.38859999999977, + 6443], + }, + {'boxes': [[268.66, 143.28, 61.01, 53.52], + [204.04, 139.88, 24.87, 36.95], + [157.51, 135.92, 26.51, 54.95], + [117.06, 135.86, 37.88, 56.7], + [192.39, 137.85, 14.19, 46.86], + [311.46, 149.17, 156.95, 88.98], + [499.59, 116.56, 140.41, 173.44], + [1.86, 147.85, 132.27, 99.36], + [124.21, 150.01, 89.08, 5.36], + [344.97, 92.63, 6.72, 30.25], + [441.77, 71.9, 10.01, 41.52], + [118.63, 153.62, 8.78, 20.32], + [291.45, 179.46, 15.1, 10.3], + [498.7, 115.61, 141.3, 174.39]], + }, + {'boxes': [[0.0, 8.53, 251.9, 214.62], + [409.89, 120.81, 47.11, 92.04], + [84.85, 0.0, 298.6, 398.22], + [159.71, 211.53, 189.89, 231.01], + [357.69, 110.26, 99.31, 90.43]], + } + ] + def get_tokenizer(self, **kwars): return DetrTokenizer() + # tests on single PIL image (inference only) + def test_tokenizer_no_resize(self): tokenizer = self.get_tokenizer() encoding = tokenizer(self.img, resize=False) @@ -54,4 +125,23 @@ def test_tokenizer(self): encoding = tokenizer(self.img) self.assertEqual(encoding["pixel_values"].shape, (1,3,800,1066)) - self.assertEqual(encoding["pixel_mask"].shape, (1,800,1066)) \ No newline at end of file + self.assertEqual(encoding["pixel_mask"].shape, (1,800,1066)) + + # tests on list of PIL images (inference only) + + def test_tokenizer_batch(self): + tokenizer = self.get_tokenizer() + encoding = tokenizer(self.images) + + self.assertEqual(encoding["pixel_values"].shape, (3,3,800,1201)) + self.assertEqual(encoding["pixel_mask"].shape, (3,800,1201)) + + # tests on list of PIL images (training) + # doesn't work yet (format of annotations?) + def test_tokenizer_batch_training(self): + tokenizer = self.get_tokenizer() + encoding = tokenizer(self.images, self.annotations) + + self.assertEqual(encoding["pixel_values"].shape, (3,3,800,1201)) + self.assertEqual(encoding["pixel_mask"].shape, (3,800,1201)) + From 55fccf6cec459baf464c7626334429379920713b Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 10:26:01 +0100 Subject: [PATCH 09/20] More improvements to DetrTokenizer --- src/transformers/models/detr/modeling_detr.py | 1 - .../models/detr/tokenization_detr.py | 35 ++++++++++--------- tests/test_tokenization_detr.py | 32 ++++++++--------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index f97455a7a89fd8..2e27da525398c9 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -186,7 +186,6 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): # TODO make it support different-sized images max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - print('Max size', max_size) # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) batch_shape = [len(tensor_list)] + max_size b, c, h, w = batch_shape diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index f79abd6f1932e1..1003407b927503 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -299,8 +299,8 @@ def __init__(self, size, max_size=None): self.size = size self.max_size = max_size - def __call__(self, img, target=None): - return resize(img, target, self.size, self.max_size) + def __call__(self, image, target=None): + return resize(image, target, self.size, self.max_size) # copied from https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/util/box_ops.py def box_xyxy_to_cxcywh(x): @@ -414,10 +414,10 @@ def __call__( :obj:`List[PIL.Image]`): The image or batch of images to be prepared. Each image can be a PIL image, numpy array or a Torch tensor. annotations (:obj:`Dict`, :obj:`List[Dict]`): - The annotations as either a single Python dictionary or a batch of Python dictionaries. Keys include "size", "area", - "boxes" and "masks". + The annotations as either a Python dictionary (in case of a single image) or a list of Python dictionaries (in case + of a batch of images). Keys include "boxes", "labels", "area" and "masks". resize (:obj:`bool`, `optional`, defaults to :obj:`True`): - Whether to resize images to a certain :obj:`size`. + Whether to resize images (and annotations) to a certain :obj:`size`. size (:obj:`int`, `optional`, defaults to :obj:`800`): Resize the input image to the given size. Only has an effect if :obj:`resize` is set to :obj:`True`. max_size (:obj:`int`, `optional`, defaults to :obj:`1333`): @@ -431,7 +431,7 @@ def __call__( and (isinstance(images[0], (PIL.Image.Image, np.ndarray, torch.Tensor))) ) - # make images a list of PIL images no matter what + # step 1: make images a list of PIL images no matter what if is_batched: if isinstance(images[0], np.ndarray): images = [Image.fromarray(image) for image in images] @@ -445,33 +445,36 @@ def __call__( if annotations is not None: annotations = [annotations] - # next: apply transformations (resizing + normalization) to both images and annotations + # step 2: define transformations (resizing + normalization) transformations = [ConvertCocoPolysToMask(),] if resize and size is not None: transformations.append(Resize(size=size, max_size=max_size)) if normalize: - normalize = Compose([ + normalization = Compose([ ToTensor(), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) - transformations.append(normalize) + transformations.append(normalization) transforms = Compose(transformations) + # step 3: apply transformations to both images and annotations + transformed_images = [] + transformed_annotations = [] if annotations is not None: - transformed_images = [] for image, annotation in zip(images, annotations): image, annotation = transforms(image, annotation) transformed_images.append(image) + transformed_annotations.append(annotation) else: - transformed_images = [transforms(image, None) for image in images] + transformed_images = [transforms(image, None)[0] for image in images] - # next, create NestedTensor which takes care of padding the pixels up to biggest image - # we don't need the transformed targets for this - samples = nested_tensor_from_tensor_list([x[0] for x in transformed_images]) + # step 4: create NestedTensor which takes care of padding the pixels up to biggest image + # and creation of mask. We don't need the transformed targets for this + samples = nested_tensor_from_tensor_list(transformed_images) - # return as dict + # return as BatchEncoding data = {'pixel_values': samples.tensors, 'pixel_mask': samples.mask, - 'labels': [x[1] for x in transformed_images] if annotations is not None else None} + 'labels': transformed_annotations if annotations is not None else None} encoded_inputs = BatchEncoding(data=data) return encoded_inputs diff --git a/tests/test_tokenization_detr.py b/tests/test_tokenization_detr.py index b30b0ab4cf56a6..af65088b4a97b4 100644 --- a/tests/test_tokenization_detr.py +++ b/tests/test_tokenization_detr.py @@ -112,36 +112,34 @@ def get_tokenizer(self, **kwars): return DetrTokenizer() # tests on single PIL image (inference only) - - def test_tokenizer_no_resize(self): - tokenizer = self.get_tokenizer() - encoding = tokenizer(self.img, resize=False) - - self.assertEqual(encoding["pixel_values"].shape, (1,3,480,640)) - self.assertEqual(encoding["pixel_mask"].shape, (1,480,640)) - def test_tokenizer(self): tokenizer = self.get_tokenizer() encoding = tokenizer(self.img) - self.assertEqual(encoding["pixel_values"].shape, (1,3,800,1066)) - self.assertEqual(encoding["pixel_mask"].shape, (1,800,1066)) + self.assertEqual(encoding["pixel_values"].shape, (1, 3, 800, 1066)) + self.assertEqual(encoding["pixel_mask"].shape, (1, 800, 1066)) + + # tests on single PIL image (inference only, with resize set to False) + def test_tokenizer_no_resize(self): + tokenizer = self.get_tokenizer() + encoding = tokenizer(self.img, resize=False) - # tests on list of PIL images (inference only) + self.assertEqual(encoding["pixel_values"].shape, (1, 3, 480, 640)) + self.assertEqual(encoding["pixel_mask"].shape, (1, 480, 640)) + # tests on batch of PIL images (inference only) def test_tokenizer_batch(self): tokenizer = self.get_tokenizer() encoding = tokenizer(self.images) - self.assertEqual(encoding["pixel_values"].shape, (3,3,800,1201)) - self.assertEqual(encoding["pixel_mask"].shape, (3,800,1201)) + self.assertEqual(encoding["pixel_values"].shape, (3, 3, 1120, 1332)) + self.assertEqual(encoding["pixel_mask"].shape, (3, 1120, 1332)) - # tests on list of PIL images (training) - # doesn't work yet (format of annotations?) + # tests on batch of PIL images (training, i.e. with annotations) def test_tokenizer_batch_training(self): tokenizer = self.get_tokenizer() encoding = tokenizer(self.images, self.annotations) - self.assertEqual(encoding["pixel_values"].shape, (3,3,800,1201)) - self.assertEqual(encoding["pixel_mask"].shape, (3,800,1201)) + self.assertEqual(encoding["pixel_values"].shape, (3, 3, 1120, 1332)) + self.assertEqual(encoding["pixel_mask"].shape, (3, 1120, 1332)) From 3a892839e8da4caa049adc9e10968da02c1cd733 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 10:28:20 +0100 Subject: [PATCH 10/20] Add print statements --- src/transformers/models/detr/modeling_detr.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 2e27da525398c9..1b038ef5f71893 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1452,6 +1452,15 @@ def forward( # First, sent images through Backbone to obtain the features (includes features map, mask and position embeddings) if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples) + # tensors are of shape (batch_size, num_channels, height, width) + print("Shape of tensors:") + print(samples.tensors.shape) + print("First few elements:") + print(samples.tensors[0,:3,:3,:3]) + print("Shape of mask:") + print(samples.mask.shape) + print("First few elements of mask:") + print(samples.tensors[0,:3,:3]) features, position_embeddings_list = self.backbone(samples) src, mask = features[-1].decompose() From a252e35e25ff9010bb0afb35d40e82dca01b5aaf Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 11:37:14 +0100 Subject: [PATCH 11/20] Make base model in DetrForObjectDetection much shorter --- .gitattributes.txt | 6 - ..._original_pytorch_checkpoint_to_pytorch.py | 5 + src/transformers/models/detr/modeling_detr.py | 146 ++++-------------- tests/test_modeling_detr.py | 4 +- 4 files changed, 40 insertions(+), 121 deletions(-) delete mode 100644 .gitattributes.txt diff --git a/.gitattributes.txt b/.gitattributes.txt deleted file mode 100644 index 06f7128826c165..00000000000000 --- a/.gitattributes.txt +++ /dev/null @@ -1,6 +0,0 @@ -# Set the default behavior, in case people don't have core.autocrlf set -* text=auto - -*.py eol=lf -*.rst eol=lf -*.md eol=lf \ No newline at end of file diff --git a/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py index 27c4cc3cc2a085..18fea997ae15bb 100644 --- a/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py @@ -245,6 +245,11 @@ def convert_detr_checkpoint(task, backbone, dilation, pytorch_dump_folder_path): # rename classification heads for src, dest in rename_keys_object_detection_model: rename_key(state_dict, src, dest) + # important: we need to prepend "model." to each of the base model keys as DetrForObjectDetection calls the base model like this + for key in state_dict.copy().keys(): + if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"): + val = state_dict.pop(key) + state_dict["model." + key] = val # finally, create model and load state dict model = DetrForObjectDetection(config).eval() model.load_state_dict(state_dict) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 1b038ef5f71893..b1e896efae79a6 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -74,6 +74,17 @@ class BaseModelOutputWithCrossAttentionsAndIntermediateHiddenStates(BaseModelOut intermediate_hidden_states: Optional[torch.FloatTensor] = None +@dataclass +class Seq2SeqModelOutputWithIntermediateHiddenStates(Seq2SeqModelOutput): + """ + This class adds one attribute to Seq2SeqModelOutput, namely an optional stack of intermediate decoder + activations, i.e. the output of each decoder layer, each of them gone through a layernorm. + Args: + intermediate_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(config.decoder_layers, batch_size, sequence_length, hidden_size)`): + """ + + intermediate_hidden_states: Optional[torch.FloatTensor] = None + @dataclass class DetrObjectDetectionOutput(ModelOutput): """ @@ -1422,12 +1433,12 @@ def get_decoder(self): return self.decoder @add_start_docstrings_to_model_forward(DETR_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - tokenizer_class=_TOKENIZER_FOR_DOC, - checkpoint="facebook/detr-resnet-50", - output_type=Seq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - ) + # @add_code_sample_docstrings( + # tokenizer_class=_TOKENIZER_FOR_DOC, + # checkpoint="facebook/detr-resnet-50", + # output_type=Seq2SeqModelOutput, + # config_class=_CONFIG_FOR_DOC, + # ) def forward( self, samples: NestedTensor=None, @@ -1453,14 +1464,6 @@ def forward( if isinstance(samples, (list, torch.Tensor)): samples = nested_tensor_from_tensor_list(samples) # tensors are of shape (batch_size, num_channels, height, width) - print("Shape of tensors:") - print(samples.tensors.shape) - print("First few elements:") - print(samples.tensors[0,:3,:3,:3]) - print("Shape of mask:") - print(samples.mask.shape) - print("First few elements of mask:") - print(samples.tensors[0,:3,:3]) features, position_embeddings_list = self.backbone(samples) src, mask = features[-1].decompose() @@ -1518,7 +1521,7 @@ def forward( if not return_dict: return decoder_outputs + encoder_outputs - return Seq2SeqModelOutput( + return Seq2SeqModelOutputWithIntermediateHiddenStates( last_hidden_state=decoder_outputs.last_hidden_state, #past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, @@ -1527,6 +1530,7 @@ def forward( encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, + intermediate_hidden_states=decoder_outputs.intermediate_hidden_states, ) @@ -1539,18 +1543,8 @@ class DetrForObjectDetection(DetrPreTrainedModel): def __init__(self, config: DetrConfig): super().__init__(config) - # Create backbone + positional encoding - backbone = Backbone(config.backbone, config.train_backbone, config.masks, config.dilation) - position_embeddings = build_position_encoding(config) - self.backbone = Joiner(backbone, position_embeddings) - - # Create projection layer - self.input_projection = nn.Conv2d(backbone.num_channels, config.d_model, kernel_size=1) - - self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) - - self.encoder = DetrEncoder(config) - self.decoder = DetrDecoder(config) + # DETR base model + self.model = DetrModel(config) # Object detection heads self.class_labels_classifier = nn.Linear(config.d_model, config.num_labels + 1) @@ -1559,20 +1553,6 @@ def __init__(self, config: DetrConfig): self.init_weights() - # def get_input_embeddings(self): - # return self.shared - - # def set_input_embeddings(self, value): - # self.shared = value - # self.encoder.embed_tokens = self.shared - # self.decoder.embed_tokens = self.shared - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - # copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py @torch.jit.unused def _set_aux_loss(self, outputs_class, outputs_coord): @@ -1611,75 +1591,15 @@ def forward( be a :obj:`torch.LongTensor` of len :obj:`(number of bounding boxes in the image,)` and the boxes a :obj:`torch.FloatTensor` of shape :obj:`(number of bounding boxes in the image, 4)`. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # First, sent images through Backbone to obtain the features (includes features map, mask and position embeddings) - if isinstance(samples, (list, torch.Tensor)): - samples = nested_tensor_from_tensor_list(samples) - features, position_embeddings_list = self.backbone(samples) - - src, mask = features[-1].decompose() - assert mask is not None - - # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) - src = self.input_projection(src) - - # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC - # In other words, turn their shape into (batch_size, sequence_length, hidden_size) - batch_size, c, h, w = src.shape - src = src.flatten(2).permute(0, 2, 1) - position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1) - mask = ~mask.flatten(1) - - # Fourth, sent src + mask + position embeddings through encoder - # src is a Tensor of shape (batch_size, heigth*width, hidden_size) - # mask is a Tensor of shape (batch_size, heigth*width) - if encoder_outputs is None: - encoder_outputs = self.encoder( - inputs_embeds=src, - attention_mask=mask, - position_embeddings=position_embeddings, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - # Fifth, sent queries i.e. tgt (initialized with zeros), query position embeddings + position embeddings - # through the decoder (which is conditioned on the encoder output) - query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) - tgt = torch.zeros_like(query_position_embeddings) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - inputs_embeds=tgt, - attention_mask=None, - position_embeddings=position_embeddings, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + # First, sent images through DETR base model to obtain encoder + decoder outputs + outputs = self.model(samples) # class logits + predicted bounding boxes # to do: make this as efficient as the original implementation - pred_logits = self.class_labels_classifier(decoder_outputs[0]) - pred_boxes = self.bbox_predictor(decoder_outputs[0]).sigmoid() + pred_logits = self.class_labels_classifier(outputs.last_hidden_state) + pred_boxes = self.bbox_predictor(outputs.last_hidden_state).sigmoid() loss, auxiliary_outputs = None, None if labels is not None: @@ -1722,7 +1642,7 @@ def forward( outputs['pred_logits'] = pred_logits outputs['pred_boxes'] = pred_boxes if self.config.auxiliary_loss: - intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[5] + intermediate = outputs.intermediate_hidden_states if return_dict else outputs[6] outputs_class = self.class_labels_classifier(intermediate) outputs_coord = self.bbox_predictor(intermediate).sigmoid() auxiliary_outputs = self._set_aux_loss(outputs_class, outputs_coord) @@ -1746,14 +1666,14 @@ def forward( pred_boxes=pred_boxes, auxiliary_outputs=auxiliary_outputs, #past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, ) - + # copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py class SetCriterion(nn.Module): diff --git a/tests/test_modeling_detr.py b/tests/test_modeling_detr.py index 563d703d7a572e..6700487d8b671e 100644 --- a/tests/test_modeling_detr.py +++ b/tests/test_modeling_detr.py @@ -323,7 +323,7 @@ class DetrModelIntegrationTests(unittest.TestCase): # return DetrTokenizer.from_pretrained('facebook/detr-resnet-50') def test_inference_no_head(self): - model = DetrModel.from_pretrained('nielsr/detr-resnet-50').to(torch_device) + model = DetrModel.from_pretrained('nielsr/detr-resnet-50-new').to(torch_device) model.eval() img = prepare_img().to(torch_device) @@ -339,7 +339,7 @@ def test_inference_no_head(self): def test_inference_object_detection_head(self): - model = DetrForObjectDetection.from_pretrained('nielsr/detr-resnet-50').to(torch_device) + model = DetrForObjectDetection.from_pretrained('nielsr/detr-resnet-50-new').to(torch_device) model.eval() img = prepare_img().to(torch_device) From 05cfe7e019a90c87308b01ccde71ec987b0633a9 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 14:23:43 +0100 Subject: [PATCH 12/20] Add print statements --- src/transformers/models/detr/modeling_detr.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index b1e896efae79a6..0020df5995c9e4 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -341,12 +341,16 @@ def __init__(self, backbone, position_embedding): def forward(self, tensor_list: NestedTensor): xs = self[0](tensor_list) + print("Xs:") + print(xs) out: List[NestedTensor] = [] pos = [] for name, x in xs.items(): out.append(x) # position encoding pos.append(self[1](x).to(x.tensors.dtype)) + print("Pos:") + print(pos) return out, pos From 2f505cf6ab7c59eb504b180abf437fc4e32cd89b Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 14:37:11 +0100 Subject: [PATCH 13/20] Add more print statements --- src/transformers/models/detr/modeling_detr.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 0020df5995c9e4..7358076ea0da54 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -341,16 +341,18 @@ def __init__(self, backbone, position_embedding): def forward(self, tensor_list: NestedTensor): xs = self[0](tensor_list) - print("Xs:") - print(xs) + print("Shape of backbone output:") + print(xs["0"].shape) out: List[NestedTensor] = [] pos = [] for name, x in xs.items(): out.append(x) # position encoding pos.append(self[1](x).to(x.tensors.dtype)) - print("Pos:") - print(pos) + + print(len(pos)) + print("Shape of position embeddings:") + print(pos[0].shape) return out, pos From c63413d3cdd20cf5520532112ea66c8324ee945f Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 14:39:41 +0100 Subject: [PATCH 14/20] Add more print statements --- src/transformers/models/detr/modeling_detr.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 7358076ea0da54..3bb71b5c77e38a 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -341,8 +341,11 @@ def __init__(self, backbone, position_embedding): def forward(self, tensor_list: NestedTensor): xs = self[0](tensor_list) - print("Shape of backbone output:") - print(xs["0"].shape) + print("The backbone outputs a NestedTensor with a Tensor and a Mask.") + print("Shape of backbone tensor:") + print(xs["0"].tensors.shape) + print("Shape of backbone mask:") + print(xs["0"].mask.shape) out: List[NestedTensor] = [] pos = [] for name, x in xs.items(): From bdd48207008a59d8e1b4e1c4e7c848a8c0753d29 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 15:35:12 +0100 Subject: [PATCH 15/20] Tests passing with inputs prepared by DetrTokenizer --- src/transformers/models/detr/modeling_detr.py | 109 ++++++++++-------- tests/test_modeling_detr.py | 25 +++- 2 files changed, 85 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 3bb71b5c77e38a..d8cd5513da2b3f 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -300,26 +300,26 @@ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) self.num_channels = num_channels - # def forward_new(self, pixel_values: Union[torch.Tensor, list[torch.Tensor]], pixel_mask: Optional[torch.Tensor]): - # xs = self.body(pixel_values) + def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): + # send pixel_values through the IntermediateLayerGetter + output_dict = self.body(pixel_values) + # currently there's no support for intermediate layers of the backbone + feature_map = output_dict["0"] + assert pixel_mask is not None + # we downsample the pixel_mask to match the feature map + mask = F.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0] + return feature_map, mask + + # # this one should be removed in the future + # def forward(self, tensor_list: NestedTensor): + # xs = self.body(tensor_list.tensors) # out: Dict[str, NestedTensor] = {} # for name, x in xs.items(): - # m = pixel_mask + # m = tensor_list.mask # assert m is not None # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] # out[name] = NestedTensor(x, mask) # return out - - # this one should be removed in the future - def forward(self, tensor_list: NestedTensor): - xs = self.body(tensor_list.tensors) - out: Dict[str, NestedTensor] = {} - for name, x in xs.items(): - m = tensor_list.mask - assert m is not None - mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] - out[name] = NestedTensor(x, mask) - return out class Backbone(BackboneBase): @@ -339,25 +339,38 @@ class Joiner(nn.Sequential): def __init__(self, backbone, position_embedding): super().__init__(backbone, position_embedding) - def forward(self, tensor_list: NestedTensor): - xs = self[0](tensor_list) - print("The backbone outputs a NestedTensor with a Tensor and a Mask.") - print("Shape of backbone tensor:") - print(xs["0"].tensors.shape) - print("Shape of backbone mask:") - print(xs["0"].mask.shape) - out: List[NestedTensor] = [] - pos = [] - for name, x in xs.items(): - out.append(x) - # position encoding - pos.append(self[1](x).to(x.tensors.dtype)) + def forward(self, pixel_values, pixel_mask): + print("Hello we're here") + print(pixel_values.shape) + print(pixel_mask.shape) - print(len(pos)) - print("Shape of position embeddings:") - print(pos[0].shape) + # first, send pixel_values and pixel_mask through backbone to obtain updated feature_map and pixel_mask + feature_map, pixel_mask = self[0](pixel_values, pixel_mask) + + # next, create position embeddings for the outputs of the outputs_dict + pos = self[1](feature_map, pixel_mask).to(feature_map.dtype) + + return feature_map, pixel_mask, pos + + # def forward(self, tensor_list: NestedTensor): + # xs = self[0](tensor_list) + # print("The backbone outputs a NestedTensor with a Tensor and a Mask.") + # print("Shape of backbone tensor:") + # print(xs["0"].tensors.shape) + # print("Shape of backbone mask:") + # print(xs["0"].mask.shape) + # out: List[NestedTensor] = [] + # pos = [] + # for name, x in xs.items(): + # out.append(x) + # # position encoding + # pos.append(self[1](x).to(x.tensors.dtype)) + + # print(len(pos)) + # print("Shape of position embeddings:") + # print(pos[0].shape) - return out, pos + # return out, pos def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): @@ -422,9 +435,9 @@ def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=N scale = 2 * math.pi self.scale = scale - def forward(self, tensor_list: NestedTensor): - x = tensor_list.tensors - mask = tensor_list.mask + def forward(self, pixel_values, pixel_mask): + x = pixel_values + mask = pixel_mask assert mask is not None not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) @@ -1450,7 +1463,9 @@ def get_decoder(self): # ) def forward( self, - samples: NestedTensor=None, + #samples: NestedTensor=None, + pixel_values, + pixel_mask, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, @@ -1469,23 +1484,25 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # First, sent images through Backbone to obtain the features (includes features map, mask and position embeddings) - if isinstance(samples, (list, torch.Tensor)): - samples = nested_tensor_from_tensor_list(samples) - # tensors are of shape (batch_size, num_channels, height, width) - features, position_embeddings_list = self.backbone(samples) + # First, sent pixel_values + pixel_mask through Backbone to obtain the features + # (includes features map, downsampled mask and position embeddings) + # pixel_values should be of shape (batch_size, num_channels, height, width) + # pixel_mask should be of shape (batch_size, height, width) + feature_map, mask, position_embeddings = self.backbone(pixel_values, pixel_mask) - src, mask = features[-1].decompose() assert mask is not None # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) - src = self.input_projection(src) + src = self.input_projection(feature_map) # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC # In other words, turn their shape into (batch_size, sequence_length, hidden_size) batch_size, c, h, w = src.shape src = src.flatten(2).permute(0, 2, 1) - position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1) + position_embeddings = position_embeddings.flatten(2).permute(0, 2, 1) + + print("Shape of position embeddings:") + print(position_embeddings.shape) mask = ~mask.flatten(1) # Fourth, sent src + mask + position embeddings through encoder @@ -1580,7 +1597,9 @@ def _set_aux_loss(self, outputs_class, outputs_coord): # ) def forward( self, - samples: NestedTensor=None, + #samples: NestedTensor=None, + pixel_values, + pixel_mask, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, @@ -1603,7 +1622,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # First, sent images through DETR base model to obtain encoder + decoder outputs - outputs = self.model(samples) + outputs = self.model(pixel_values, pixel_mask) # class logits + predicted bounding boxes # to do: make this as efficient as the original implementation diff --git a/tests/test_modeling_detr.py b/tests/test_modeling_detr.py index 6700487d8b671e..05b388b27588f0 100644 --- a/tests/test_modeling_detr.py +++ b/tests/test_modeling_detr.py @@ -315,6 +315,17 @@ def prepare_img(): return img +def prepare_detr_inputs(): + url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + im = Image.open(requests.get(url, stream=True).raw) + + tokenizer = DetrTokenizer() + + encoding = tokenizer(im) + + return encoding + + @require_torch @slow class DetrModelIntegrationTests(unittest.TestCase): @@ -325,10 +336,13 @@ class DetrModelIntegrationTests(unittest.TestCase): def test_inference_no_head(self): model = DetrModel.from_pretrained('nielsr/detr-resnet-50-new').to(torch_device) model.eval() - img = prepare_img().to(torch_device) + + encoding = prepare_detr_inputs() + pixel_values = encoding['pixel_values'].to(torch_device) + pixel_mask = encoding['pixel_mask'].to(torch_device) with torch.no_grad(): - outputs = model(img) + outputs = model(pixel_values, pixel_mask) expected_shape = torch.Size((1, 100, 256)) assert outputs.last_hidden_state.shape == expected_shape @@ -341,10 +355,13 @@ def test_inference_no_head(self): def test_inference_object_detection_head(self): model = DetrForObjectDetection.from_pretrained('nielsr/detr-resnet-50-new').to(torch_device) model.eval() - img = prepare_img().to(torch_device) + + encoding = prepare_detr_inputs() + pixel_values = encoding['pixel_values'].to(torch_device) + pixel_mask = encoding['pixel_mask'].to(torch_device) with torch.no_grad(): - outputs = model(img) + outputs = model(pixel_values, pixel_mask) expected_shape_logits = torch.Size((1, model.config.num_queries, model.config.num_labels + 1)) self.assertEqual(outputs.pred_logits.shape, expected_shape_logits) From 08b44c2e7672c68618b6ab10205f447c1d3343f3 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 15:51:10 +0100 Subject: [PATCH 16/20] Fix to outputs of DetrForObjectDetection --- src/transformers/models/detr/modeling_detr.py | 12 +++++++++++- src/transformers/models/detr/tokenization_detr.py | 7 +++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index d8cd5513da2b3f..61ffffff9de033 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1622,7 +1622,17 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # First, sent images through DETR base model to obtain encoder + decoder outputs - outputs = self.model(pixel_values, pixel_mask) + outputs = self.model(pixel_values, pixel_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=Nodecoder_attention_maskne, + encoder_outputs=Nencoder_outputsone, + #past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict,) # class logits + predicted bounding boxes # to do: make this as efficient as the original implementation diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index 1003407b927503..05baf2e55f5ac7 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -473,8 +473,11 @@ def __call__( samples = nested_tensor_from_tensor_list(transformed_images) # return as BatchEncoding - data = {'pixel_values': samples.tensors, 'pixel_mask': samples.mask, - 'labels': transformed_annotations if annotations is not None else None} + data = {'pixel_values': samples.tensors, 'pixel_mask': samples.mask}, + + if annotations is not None: + data['labels'] = transformed_annotations + encoded_inputs = BatchEncoding(data=data) return encoded_inputs From b3d36b8fd4c52f518f5b9b2c7c7b8578ab02f4d6 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Tue, 9 Feb 2021 15:54:22 +0100 Subject: [PATCH 17/20] Bug fixes --- src/transformers/models/detr/modeling_detr.py | 4 ++-- src/transformers/models/detr/tokenization_detr.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 61ffffff9de033..82f6d6443fb28b 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1624,8 +1624,8 @@ def forward( # First, sent images through DETR base model to obtain encoder + decoder outputs outputs = self.model(pixel_values, pixel_mask, decoder_input_ids=decoder_input_ids, - decoder_attention_mask=Nodecoder_attention_maskne, - encoder_outputs=Nencoder_outputsone, + decoder_attention_mask=decoder_attention_mask, + encoder_outputs=encoder_outputs, #past_key_values=past_key_values, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index 05baf2e55f5ac7..731e23cdf63960 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -473,11 +473,11 @@ def __call__( samples = nested_tensor_from_tensor_list(transformed_images) # return as BatchEncoding - data = {'pixel_values': samples.tensors, 'pixel_mask': samples.mask}, + data = {'pixel_values': samples.tensors, 'pixel_mask': samples.mask} if annotations is not None: data['labels'] = transformed_annotations - + encoded_inputs = BatchEncoding(data=data) return encoded_inputs From 0bb49308ce44ac7d3feaf61f6cdd8ba6a0e04a2d Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 10 Feb 2021 17:23:20 +0100 Subject: [PATCH 18/20] Some cleanup --- src/transformers/models/detr/modeling_detr.py | 12 +++--------- src/transformers/models/detr/tokenization_detr.py | 1 - 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 82f6d6443fb28b..e1789fe335140a 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -339,11 +339,7 @@ class Joiner(nn.Sequential): def __init__(self, backbone, position_embedding): super().__init__(backbone, position_embedding) - def forward(self, pixel_values, pixel_mask): - print("Hello we're here") - print(pixel_values.shape) - print(pixel_mask.shape) - + def forward(self, pixel_values, pixel_mask): # first, send pixel_values and pixel_mask through backbone to obtain updated feature_map and pixel_mask feature_map, pixel_mask = self[0](pixel_values, pixel_mask) @@ -472,8 +468,8 @@ def reset_parameters(self): nn.init.uniform_(self.row_embeddings.weight) nn.init.uniform_(self.column_embeddings.weight) - def forward(self, tensor_list: NestedTensor): - x = tensor_list.tensors + def forward(self, pixel_values, pixel_mask=None): + x = pixel_values h, w = x.shape[-2:] i = torch.arange(w, device=x.device) j = torch.arange(h, device=x.device) @@ -1501,8 +1497,6 @@ def forward( src = src.flatten(2).permute(0, 2, 1) position_embeddings = position_embeddings.flatten(2).permute(0, 2, 1) - print("Shape of position embeddings:") - print(position_embeddings.shape) mask = ~mask.flatten(1) # Fourth, sent src + mask + position embeddings through encoder diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index 731e23cdf63960..11e4b3346ba0b9 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -108,7 +108,6 @@ def __repr__(self): def nested_tensor_from_tensor_list(tensor_list: Union[List[Tensor], torch.Tensor]): # TODO make this more n - print(tensor_list[0].shape) if tensor_list[0].ndim == 3: if torchvision._is_tracing(): # nested_tensor_from_tensor_list() does not export well to ONNX From 41806338bdea9447abfd2badf3d39719949f0147 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Wed, 10 Feb 2021 17:40:15 +0100 Subject: [PATCH 19/20] Invert input mask to be Transformers-compliant --- src/transformers/models/detr/modeling_detr.py | 7 +++---- src/transformers/models/detr/tokenization_detr.py | 5 +++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index e1789fe335140a..f0be6dd8833767 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -435,9 +435,8 @@ def forward(self, pixel_values, pixel_mask): x = pixel_values mask = pixel_mask assert mask is not None - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale @@ -1497,7 +1496,7 @@ def forward( src = src.flatten(2).permute(0, 2, 1) position_embeddings = position_embeddings.flatten(2).permute(0, 2, 1) - mask = ~mask.flatten(1) + mask = mask.flatten(1) # Fourth, sent src + mask + position embeddings through encoder # src is a Tensor of shape (batch_size, heigth*width, hidden_size) diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index 11e4b3346ba0b9..f11f371645e0f2 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -122,10 +122,10 @@ def nested_tensor_from_tensor_list(tensor_list: Union[List[Tensor], torch.Tensor dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + mask = torch.zeros((b, h, w), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], :img.shape[2]] = False + m[: img.shape[1], :img.shape[2]] = True else: raise ValueError('Not supported') return NestedTensor(tensor, mask) @@ -133,6 +133,7 @@ def nested_tensor_from_tensor_list(tensor_list: Union[List[Tensor], torch.Tensor # _onnx_nested_tensor_from_tensor_list() is an implementation of # nested_tensor_from_tensor_list() that is supported by ONNX tracing. +# Note: inverting mask has not yet been taken into account @torch.jit.unused def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: max_size = [] From 403aeeb0ea772a223f8a923e3f55824b42f91134 Mon Sep 17 00:00:00 2001 From: Niels Rogge Date: Thu, 11 Feb 2021 11:08:36 +0100 Subject: [PATCH 20/20] Improve docstrings + first draft to support directories for DetrTokenizer --- src/transformers/models/detr/modeling_detr.py | 155 +++++++----------- .../models/detr/tokenization_detr.py | 77 +++++++-- 2 files changed, 129 insertions(+), 103 deletions(-) diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index f0be6dd8833767..5fca5d9072a093 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -911,46 +911,24 @@ def dummy_inputs(self): weights. """ -DETR_GENERATION_EXAMPLE = r""" - Summarization example:: - - >>> from transformers import DetrTokenizer, DetrForConditionalGeneration, DetrConfig - - >>> model = DetrForConditionalGeneration.from_pretrained('facebook/detr-resnet-50') - >>> tokenizer = DetrTokenizer.from_pretrained('facebook/detr-resnet-50') - - >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." - >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') - - >>> # Generate Summary - >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) - >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) -""" - DETR_INPUTS_DOCSTRING = r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. - Indices can be obtained using :class:`~transformers.DetrTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for - details. + Pixel values can be obtained using :class:`~transformers.DetrTokenizer`. See + :meth:`transformers.DetrTokenizer.__call__` for details. - `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + pixel_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding pixel values. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). `What are attention masks? <../glossary.html#attention-mask>`__ - decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): - Provide for translation and summarization training. By default, the model will create this tensor by - shifting the :obj:`input_ids` to the right, following the paper. + decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): - Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will - also be used by default. + Not used by default. If you want to change padding behavior, you should read :func:`modeling_detr._prepare_decoder_inputs` and modify to your needs. See diagram 1 in `the paper `__ for more @@ -960,12 +938,6 @@ def dummy_inputs(self): :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. - - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert :obj:`input_ids` indices into associated @@ -997,6 +969,11 @@ class DetrEncoder(DetrPreTrainedModel): Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a :class:`DetrEncoderLayer`. + The encoder updates the flattened feature map through multiple self-attention layers. + + Small tweaks for DETR: + - position_embeddings are added to the forward pass. + Args: config: DetrConfig embed_tokens (torch.nn.Embedding): output embedding @@ -1032,9 +1009,9 @@ def __init__(self, config: DetrConfig): def forward( self, - input_ids=None, - attention_mask=None, + #input_ids=None, inputs_embeds=None, + attention_mask=None, position_embeddings=None, output_attentions=None, output_hidden_states=None, @@ -1042,26 +1019,20 @@ def forward( ): r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using :class:`~transformers.DetrTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` - for details. - - `What are input IDs? <../glossary.html#input-ids>`__ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + Mask to avoid performing attention on padding pixel features. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + - 1 for pixel features that are real (i.e. **not masked**), + - 0 for pixel features that are padding (i.e. **masked**). `What are attention masks? <../glossary.html#attention-mask>`__ - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded - representation. This is useful if you want more control over how to convert :obj:`input_ids` indices - into associated vectors than the model's internal embedding lookup matrix. + + position_embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Position embeddings that are added to the queries and keys in each self-attention layer. + output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. @@ -1158,6 +1129,8 @@ class DetrDecoder(DetrPreTrainedModel): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`DetrDecoderLayer`. + The decoder updates the query embeddings through multiple self-attention and cross-attention layers. + Some small tweaks for DETR: - position_embeddings and query_position_embeddings are added to the forward pass. - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. @@ -1195,12 +1168,12 @@ def __init__(self, config: DetrConfig, embed_tokens: Optional[nn.Embedding] = No def forward( self, - input_ids=None, + #input_ids=None, + inputs_embeds=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, - inputs_embeds=None, position_embeddings=None, query_position_embeddings=None, use_cache=None, @@ -1210,31 +1183,25 @@ def forward( ): r""" Args: - input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using :class:`~transformers.DetrTokenizer`. See - :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` - for details. - - `What are input IDs? <../glossary.html#input-ids>`__ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + The query embeddings that are passed into the decoder. + attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + Mask to avoid performing attention on certain queries. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. `What are attention masks? <../glossary.html#attention-mask>`__ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). `What are attention masks? <../glossary.html#attention-mask>`__ past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): @@ -1245,10 +1212,6 @@ def forward( :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded - representation. This is useful if you want more control over how to convert :obj:`input_ids` indices - into associated vectors than the model's internal embedding lookup matrix. output_attentions (:obj:`bool`, `optional`): Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned tensors for more detail. @@ -1265,6 +1228,7 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # (Niels) following lines are not required as DETR doesn't use input_ids and inputs_embeds # retrieve input_ids and inputs_embeds # if input_ids is not None and inputs_embeds is not None: # raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -1279,9 +1243,9 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: - # to do: should be updated, because no input_ids here - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + # to do: should be updated, because no input_ids here + # if inputs_embeds is None: + # inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale # added this (Niels) to infer input_shape: if inputs_embeds is not None: @@ -1460,7 +1424,7 @@ def forward( self, #samples: NestedTensor=None, pixel_values, - pixel_mask, + pixel_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, @@ -1479,6 +1443,12 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + # First, sent pixel_values + pixel_mask through Backbone to obtain the features # (includes features map, downsampled mask and position embeddings) # pixel_values should be of shape (batch_size, num_channels, height, width) @@ -1488,23 +1458,22 @@ def forward( assert mask is not None # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) - src = self.input_projection(feature_map) + projected_feature_map = self.input_projection(feature_map) # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC # In other words, turn their shape into (batch_size, sequence_length, hidden_size) - batch_size, c, h, w = src.shape - src = src.flatten(2).permute(0, 2, 1) + flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1) position_embeddings = position_embeddings.flatten(2).permute(0, 2, 1) - mask = mask.flatten(1) + flattened_mask = mask.flatten(1) - # Fourth, sent src + mask + position embeddings through encoder - # src is a Tensor of shape (batch_size, heigth*width, hidden_size) - # mask is a Tensor of shape (batch_size, heigth*width) + # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder + # flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size) + # flattened_mask is a Tensor of shape (batch_size, heigth*width) if encoder_outputs is None: encoder_outputs = self.encoder( - inputs_embeds=src, - attention_mask=mask, + inputs_embeds=flattened_features, + attention_mask=flattened_mask, position_embeddings=position_embeddings, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1520,16 +1489,16 @@ def forward( # Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output) query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) - tgt = torch.zeros_like(query_position_embeddings) + queries = torch.zeros_like(query_position_embeddings) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( - inputs_embeds=tgt, + inputs_embeds=queries, attention_mask=None, position_embeddings=position_embeddings, query_position_embeddings=query_position_embeddings, encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=mask, + encoder_attention_mask=flattened_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, diff --git a/src/transformers/models/detr/tokenization_detr.py b/src/transformers/models/detr/tokenization_detr.py index f11f371645e0f2..26b618ce9d560a 100644 --- a/src/transformers/models/detr/tokenization_detr.py +++ b/src/transformers/models/detr/tokenization_detr.py @@ -14,11 +14,11 @@ # limitations under the License. """Tokenization class for DETR.""" -import json -import os -from itertools import groupby from typing import Dict, List, Optional, Tuple, Union -import random +from pathlib import Path +import json +from collections import defaultdict +import time import numpy as np import torch @@ -390,12 +390,16 @@ def __init__( ) self._word_delimiter_token = word_delimiter_token self.do_lower_case = do_lower_case + + # load dataset + self.dataset ,self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict() + self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) @add_end_docstrings(DETR_KWARGS_DOCSTRING) def __call__( self, - images: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor]], - annotations: Optional[Union[Dict, List[Dict]]] = None, + images: Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image], List[np.ndarray], List[torch.Tensor], str], + annotations: Optional[Union[Dict, List[Dict], str]] = None, padding: Union[bool, str] = True, return_mask: Union[bool, str] = True, resize: Optional[bool] = True, @@ -410,12 +414,20 @@ def __call__( the largest image in a batch while creating a pixel mask for each image indicating which pixels are real and which are padding. Args: - images (:obj:`PIL.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`, - :obj:`List[PIL.Image]`): - The image or batch of images to be prepared. Each image can be a PIL image, numpy array or a Torch tensor. + images (:obj:`PIL.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`, :obj:`str`): + The image or batch of images to be prepared. Each image can be a PIL image, numpy array or a Torch tensor. You can also provide the name of the path + to a directory containing the images. annotations (:obj:`Dict`, :obj:`List[Dict]`): The annotations as either a Python dictionary (in case of a single image) or a list of Python dictionaries (in case - of a batch of images). Keys include "boxes", "labels", "area" and "masks". + of a batch of images). Each dictionary should include the following keys: + + - boxes (:obj:`List[List[float]]`): the coordinates of the N bounding boxes in [x0, y0, x1, y1] format: the x and y coordinate of the top left and the height and width. + - labels (:obj:`List[int]`): the label for each bounding box. 0 represents the background (i.e. 'no object') class. + - (optionally) masks (:obj:`List[List[float]]`): the segmentation masks for each of the objects. + - (optionally) keypoints (FloatTensor[N, K, 3]): for each one of the N objects, it contains the K keypoints in [x, y, visibility] format, defining the object. visibility=0 means that the keypoint is not visible. Note that for data augmentation, the notion of flipping a keypoint is dependent on the data representation, and you should probably adapt references/detection/transforms.py for your new keypoint representation. + + Instead, you can also provide a path to a json file in COCO format. + resize (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to resize images (and annotations) to a certain :obj:`size`. size (:obj:`int`, `optional`, defaults to :obj:`800`): @@ -426,6 +438,51 @@ def __call__( Whether to apply standard ImageNet mean/std normalization of images. """ + # images can be a path to a directory, and annotations can be a path to a json file in COCO format + # the following is based on the COCO class of https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py + if isinstance(images, str) and isinstance(annotations, str): + with open(annotations) as f: + if verbose: + logger.info('Loading annotations into memory...') + tic = time.time() + dataset = json.load(open(annotations, 'r')) + assert type(dataset)==dict, 'Annotation file format {} not supported'.format(type(dataset)) + if verbose: + logger.info('Done (t={:0.2f}s)'.format(time.time()- tic)) + + imgToAnns = defaultdict(list) + anns, imgs = {} + if verbose: + # create index + logger.info('Creating index...') + + if 'images' in self.dataset: + for img in self.dataset['images']: + imgs[img['id']] = img + + if 'annotations' in self.dataset: + for ann in self.dataset['annotations']: + imgToAnns[ann['image_id']].append(ann) + anns[ann['id']] = ann + + if verbose: + logger.info('Index created!') + + # next, turn into list of dicts (each dict should correspond to an image in 'images') + # this is based on https://github.com/facebookresearch/detr/blob/a54b77800eb8e64e3ad0d8237789fcbf2f8350c5/datasets/coco.py#L50 + images_path = Path(images).glob('*/*.jpg') + images = [] + annotations = [] + for image_file_name in images_path: + if image_file_name.is_file(): + images.append(Image.open(image_file_name).convert("RGB")) + annotation_dict = {} + annotation_dict['boxes'] = [obj["bbox"] for obj in imgToAnns[image_file_name]] + annotation_dict['area'] = [obj["area"] for obj in imgToAnns[image_file_name]] + annotation_dict['classes'] = [obj["category_id"] for obj in imgToAnns[image_file_name]] + annotations.append(annotation_dict) + + is_batched = bool( isinstance(images, (list, tuple)) and (isinstance(images[0], (PIL.Image.Image, np.ndarray, torch.Tensor)))