From 9f1260971f041f4dcabf063ca2964847c3e5fc2a Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 13 Apr 2021 00:07:10 +0200 Subject: [PATCH] Add DeiT (PyTorch) (#11056) * First draft of deit * More improvements * Remove DeiTTokenizerFast from init * Conversion script works * Add DeiT to ViT conversion script * Add tests, add head model, add support for deit in vit conversion script * Update model checkpoint names * Update image_mean and image_std, set resample to bicubic * Improve docs * Docs improvements * Add DeiTForImageClassificationWithTeacher to init * Address comments by @sgugger * Improve feature extractors * Make fix-copies * Minor fixes * Address comments by @patil-suraj * All models uploaded * Fix tests * Remove labels argument from DeiTForImageClassificationWithTeacher * Fix-copies, style and quality * Fix tests * Fix typo * Multiple docs improvements * More docs fixes --- README.md | 1 + docs/source/index.rst | 82 +- docs/source/model_doc/deit.rst | 109 +++ docs/source/model_doc/vit.rst | 10 +- src/transformers/__init__.py | 21 + src/transformers/image_utils.py | 4 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 4 + .../models/auto/feature_extraction_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 8 +- src/transformers/models/deit/__init__.py | 72 ++ .../models/deit/configuration_deit.py | 117 +++ .../deit/convert_deit_timm_to_pytorch.py | 214 +++++ .../models/deit/feature_extraction_deit.py | 156 ++++ src/transformers/models/deit/modeling_deit.py | 770 ++++++++++++++++++ .../models/vit/convert_vit_timm_to_pytorch.py | 61 +- .../models/vit/feature_extraction_vit.py | 44 +- src/transformers/models/vit/modeling_vit.py | 4 +- src/transformers/utils/dummy_pt_objects.py | 31 + .../utils/dummy_vision_objects.py | 5 + tests/test_configuration_common.py | 6 +- tests/test_feature_extraction_deit.py | 229 ++++++ tests/test_feature_extraction_vit.py | 12 +- tests/test_modeling_deit.py | 396 +++++++++ tests/test_modeling_vit.py | 19 +- 25 files changed, 2271 insertions(+), 108 deletions(-) create mode 100644 docs/source/model_doc/deit.rst create mode 100644 src/transformers/models/deit/__init__.py create mode 100644 src/transformers/models/deit/configuration_deit.py create mode 100644 src/transformers/models/deit/convert_deit_timm_to_pytorch.py create mode 100644 src/transformers/models/deit/feature_extraction_deit.py create mode 100644 src/transformers/models/deit/modeling_deit.py create mode 100644 tests/test_feature_extraction_deit.py create mode 100644 tests/test_modeling_deit.py diff --git a/README.md b/README.md index 18b2eff45b6cdf..1e7ced945aa84a 100644 --- a/README.md +++ b/README.md @@ -204,6 +204,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[CTRL](https://huggingface.co/transformers/model_doc/ctrl.html)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. 1. **[DeBERTa](https://huggingface.co/transformers/model_doc/deberta.html)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. 1. 
**[DeBERTa-v2](https://huggingface.co/transformers/model_doc/deberta_v2.html)** (from Microsoft) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. +1. **[DeiT](https://huggingface.co/transformers/model_doc/deit.html)** (from Facebook) released with the paper [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. 1. **[DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html)** (from Microsoft Research) released with the paper [DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation](https://arxiv.org/abs/1911.00536) by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. 1. **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace), released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into [DistilGPT2](https://github.com/huggingface/transformers/tree/master/examples/distillation), RoBERTa into [DistilRoBERTa](https://github.com/huggingface/transformers/tree/master/examples/distillation), Multilingual BERT into [DistilmBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation) and a German version of DistilBERT. 1. **[DPR](https://huggingface.co/transformers/model_doc/dpr.html)** (from Facebook) released with the paper [Dense Passage Retrieval diff --git a/docs/source/index.rst b/docs/source/index.rst index 044af02732ae6d..a2ad13949d974c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -128,119 +128,122 @@ and conversion utilities for the following models: 15. :doc:`DeBERTa-v2 ` (from Microsoft) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -16. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale +16. :doc:`DeiT ` (from Facebook) released with the paper `Training data-efficient image transformers & + distillation through attention `__ by Hugo Touvron, Matthieu Cord, Matthijs + Douze, Francisco Massa, Alexandre Sablayrolles, Hervé Jégou. +17. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation `__ by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. -17. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a +18. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter `__ by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into `DistilGPT2 `__, RoBERTa into `DistilRoBERTa `__, Multilingual BERT into `DistilmBERT `__ and a German version of DistilBERT. -18. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain +19. 
:doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain Question Answering `__ by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. -19. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: +20. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: Pre-training text encoders as discriminators rather than generators `__ by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. -20. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model +21. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model Pre-training for French `__ by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -21. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: +22. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing `__ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. -22. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative +23. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative Pre-Training `__ by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. -23. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask +24. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -24. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo +25. :doc:`GPT Neo ` (from EleutherAI) released in the repository `EleutherAI/gpt-neo `__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy. -25. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization +26. :doc:`I-BERT ` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization `__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer -26. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +27. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -27. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +28. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -28. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +29. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -29. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +30. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -30. 
:doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual +31. :doc:`M2M100 ` (from Facebook) released with the paper `Beyond English-Centric Multilingual Machine Translation `__ by by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. -31. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +32. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -32. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +33. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -33. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible +34. :doc:`MBart-50 ` (from Facebook) released with the paper `Multilingual Translation with Extensible Multilingual Pretraining and Finetuning `__ by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan. -34. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training +35. :doc:`Megatron-BERT ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -35. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training +36. :doc:`Megatron-GPT2 ` (from NVIDIA) released with the paper `Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism `__ by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro. -36. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +37. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -37. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +38. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -38. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +39. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -39. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +40. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -40. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +41. 
:doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -41. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +42. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. -42. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper +43. :doc:`SpeechToTextTransformer ` (from Facebook), released together with the paper `fairseq S2T: Fast Speech-to-Text Modeling with fairseq `__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. -43. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +44. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -44. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +45. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -45. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +46. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -46. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +47. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -47. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 +48. :doc:`Vision Transformer (ViT) ` (from Google AI) released with the paper `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale `__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. -48. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for +49. :doc:`Wav2Vec2 ` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations `__ by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli. -49. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +50. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -50. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +51. 
:doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -51. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +52. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -52. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +53. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. -53. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised +54. :doc:`XLSR-Wav2Vec2 ` (from Facebook AI) released with the paper `Unsupervised Cross-Lingual Representation Learning For Speech Recognition `__ by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli. @@ -285,6 +288,8 @@ TensorFlow and/or Flax. +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | DeBERTa-v2 | ✅ | ❌ | ✅ | ❌ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ +| DeiT | ❌ | ❌ | ✅ | ❌ | ❌ | ++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | DistilBERT | ✅ | ✅ | ✅ | ✅ | ❌ | +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+ | ELECTRA | ✅ | ✅ | ✅ | ✅ | ❌ | @@ -447,6 +452,7 @@ TensorFlow and/or Flax. model_doc/ctrl model_doc/deberta model_doc/deberta_v2 + model_doc/deit model_doc/dialogpt model_doc/distilbert model_doc/dpr diff --git a/docs/source/model_doc/deit.rst b/docs/source/model_doc/deit.rst new file mode 100644 index 00000000000000..add47b5916e158 --- /dev/null +++ b/docs/source/model_doc/deit.rst @@ -0,0 +1,109 @@ +.. + Copyright 2021 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +DeiT +----------------------------------------------------------------------------------------------------------------------- + +.. note:: + + This is a recently introduced model so the API hasn't been tested extensively. There may be some bugs or slight + breaking changes to fix it in the future. If you see something strange, file a `Github Issue + `__. 
+ + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The DeiT model was proposed in `Training data-efficient image transformers & distillation through attention +`__ by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre +Sablayrolles, Hervé Jégou. The `Vision Transformer (ViT) `__ +introduced in `Dosovitskiy et al., 2020 `__ has shown that one can match or even +outperform existing convolutional neural networks using a Transformer encoder (BERT-like). However, the ViT models +introduced in that paper required training on expensive infrastructure for multiple weeks, using external data. DeiT +(data-efficient image transformers) are more efficiently trained transformers for image classification, requiring far +less data and far less computing resources compared to the original ViT models. + +The abstract from the paper is the following: + +*Recently, neural networks purely based on attention were shown to address image understanding tasks such as image +classification. However, these visual transformers are pre-trained with hundreds of millions of images using an +expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free +transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision +transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external +data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation +token ensuring that the student learns from the teacher through attention. We show the interest of this token-based +distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets +for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and +models.* + +Tips: + +- Compared to ViT, DeiT models use a so-called distillation token to effectively learn from a teacher (which, in the + DeiT paper, is a ResNet like-model). The distillation token is learned through backpropagation, by interacting with + the class ([CLS]) and patch tokens through the self-attention layers. +- There are 2 ways to fine-tune distilled models, either (1) in a classic way, by only placing a prediction head on top + of the final hidden state of the class token and not using the distillation signal, or (2) by placing both a + prediction head on top of the class token and on top of the distillation token. In that case, the [CLS] prediction + head is trained using regular cross-entropy between the prediction of the head and the ground-truth label, while the + distillation prediction head is trained using hard distillation (cross-entropy between the prediction of the + distillation head and the label predicted by the teacher). At inference time, one takes the average prediction + between both heads as final prediction. (2) is also called "fine-tuning with distillation", because one relies on a + teacher that has already been fine-tuned on the downstream dataset. In terms of models, (1) corresponds to + :class:`~transformers.DeiTForImageClassification` and (2) corresponds to + :class:`~transformers.DeiTForImageClassificationWithTeacher`. 
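  As a quick illustration of variant (2), here is a minimal inference sketch (not part of the original patch) using
  the `facebook/deit-base-distilled-patch16-224` checkpoint listed in this PR; the image URL is the same COCO image
  used in the conversion script::

    import torch
    import requests
    from PIL import Image

    from transformers import DeiTFeatureExtractor, DeiTForImageClassificationWithTeacher

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    feature_extractor = DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224")
    model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224")

    # the returned logits are the average of the [CLS] head and the distillation head predictions
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print("Predicted class:", model.config.id2label[logits.argmax(-1).item()])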
+- Note that the authors also did try soft distillation for (2) (in which case the distillation prediction head is + trained using KL divergence to match the softmax output of the teacher), but hard distillation gave the best results. +- All released checkpoints were pre-trained and fine-tuned on ImageNet-1k only. No external data was used. This is in + contrast with the original ViT model, which used external data like the JFT-300M dataset/Imagenet-21k for + pre-training. +- The authors of DeiT also released more efficiently trained ViT models, which you can directly plug into + :class:`~transformers.ViTModel` or :class:`~transformers.ViTForImageClassification`. Techniques like data + augmentation, optimization, and regularization were used in order to simulate training on a much larger dataset + (while only using ImageNet-1k for pre-training). There are 4 variants available (in 3 different sizes): + `facebook/deit-tiny-patch16-224`, `facebook/deit-small-patch16-224`, `facebook/deit-base-patch16-224` and + `facebook/deit-base-patch16-384`. Note that one should use :class:`~transformers.DeiTFeatureExtractor` in order to + prepare images for the model. + + +DeiTConfig +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeiTConfig + :members: + + +DeiTFeatureExtractor +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeiTFeatureExtractor + :members: __call__ + + +DeiTModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeiTModel + :members: forward + + +DeiTForImageClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeiTForImageClassification + :members: forward + + +DeiTForImageClassificationWithTeacher +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeiTForImageClassificationWithTeacher + :members: forward diff --git a/docs/source/model_doc/vit.rst b/docs/source/model_doc/vit.rst index 831d4f484de74e..b747a490df54b8 100644 --- a/docs/source/model_doc/vit.rst +++ b/docs/source/model_doc/vit.rst @@ -1,5 +1,5 @@ .. - Copyright 2020 The HuggingFace Team. All rights reserved. + Copyright 2021 The HuggingFace Team. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -47,10 +47,6 @@ Tips: which are then linearly embedded. A [CLS] token is added to serve as representation of an entire image, which can be used for classification. The authors also add absolute position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder. -- The Vision Transformer was pre-trained using a resolution of 224x224. During fine-tuning, it is often beneficial to - use a higher resolution than pre-training `(Touvron et al., 2019) `__, `(Kolesnikov - et al., 2020) `__. The authors report the best results with a resolution of 384x384 - during fine-tuning. 
- As the Vision Transformer expects each image to be of the same size (resolution), one can use :class:`~transformers.ViTFeatureExtractor` to resize (or rescale) and normalize images for the model. - Both the patch resolution and image resolution used during pre-training or fine-tuning are reflected in the name of @@ -61,6 +57,10 @@ Tips: 14 million images and 21k classes) only, or (2) also fine-tuned on `ImageNet `__ (also referred to as ILSVRC 2012, a collection of 1.3 million images and 1,000 classes). +- The Vision Transformer was pre-trained using a resolution of 224x224. During fine-tuning, it is often beneficial to + use a higher resolution than pre-training `(Touvron et al., 2019) `__, `(Kolesnikov + et al., 2020) `__. In order to fine-tune at higher resolution, the authors perform + 2D interpolation of the pre-trained position embeddings, according to their location in the original image. - The best results are obtained with supervised pre-training, which is not the case in NLP. The authors also performed an experiment with a self-supervised pre-training objective, namely masked patched prediction (inspired by masked language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f71e075eaaed40..3e72488be2ab44 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -167,6 +167,7 @@ "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"], "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"], "models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"], + "models.deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"], "models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"], "models.dpr": [ "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -380,6 +381,7 @@ # Vision-specific objects if is_vision_available(): _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] + _import_structure["models.deit"].append("DeiTFeatureExtractor") _import_structure["models.vit"].append("ViTFeatureExtractor") else: from .utils import dummy_vision_objects @@ -456,6 +458,7 @@ "load_tf_weights_in_albert", ] ) + _import_structure["models.auto"].extend( [ "MODEL_FOR_CAUSAL_LM_MAPPING", @@ -610,6 +613,15 @@ "DebertaV2PreTrainedModel", ] ) + _import_structure["models.deit"].extend( + [ + "DEIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "DeiTForImageClassification", + "DeiTForImageClassificationWithTeacher", + "DeiTModel", + "DeiTPreTrainedModel", + ] + ) _import_structure["models.distilbert"].extend( [ "DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1506,6 +1518,7 @@ from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer from .models.deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaTokenizer from .models.deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config + from .models.deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer from .models.dpr import ( DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -1692,6 +1705,7 @@ if is_vision_available(): from .image_utils import ImageFeatureExtractionMixin + from .models.deit import DeiTFeatureExtractor from .models.vit import ViTFeatureExtractor else: from .utils.dummy_vision_objects import * 
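A sketch of the resulting top-level import surface (assuming torch and the Pillow-based vision extra are available;
when they are not, the dummy objects imported above are used instead):

    from transformers import (
        DeiTConfig,
        DeiTFeatureExtractor,
        DeiTForImageClassification,
        DeiTForImageClassificationWithTeacher,
        DeiTModel,
    )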
@@ -1892,6 +1906,13 @@ DebertaV2Model, DebertaV2PreTrainedModel, ) + from .models.deit import ( + DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, + DeiTForImageClassification, + DeiTForImageClassificationWithTeacher, + DeiTModel, + DeiTPreTrainedModel, + ) from .models.distilbert import ( DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, DistilBertForMaskedLM, diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index fd6f31e03db3c8..add2ccac8d1d9f 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -19,6 +19,10 @@ from .file_utils import _is_torch, is_torch_available +IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] +IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] + + def is_torch_tensor(obj): return _is_torch(obj) if is_torch_available() else False diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 0092c46a976768..54f1e1021781da 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -33,6 +33,7 @@ cpm, ctrl, deberta, + deit, dialogpt, distilbert, dpr, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index aa095c4e6a7849..08003a90780432 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -33,6 +33,7 @@ from ..ctrl.configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig from ..deberta.configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig from ..deberta_v2.configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config +from ..deit.configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig from ..distilbert.configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig from ..dpr.configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig from ..electra.configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig @@ -84,6 +85,7 @@ (key, value) for pretrained_map in [ # Add archive maps here + DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP, MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -135,6 +137,7 @@ CONFIG_MAPPING = OrderedDict( [ # Add configs here + ("deit", DeiTConfig), ("gpt_neo", GPTNeoConfig), ("big_bird", BigBirdConfig), ("speech_to_text", Speech2TextConfig), @@ -192,6 +195,7 @@ MODEL_NAMES_MAPPING = OrderedDict( [ # Add full (and cased) model names here + ("deit", "DeiT"), ("gpt_neo", "GPT Neo"), ("big_bird", "BigBird"), ("speech_to_text", "Speech2Text"), diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 097a336c96dba6..496e4d5b741a4b 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -28,14 +28,17 @@ Speech2TextFeatureExtractor = None if is_vision_available(): + from ..deit.feature_extraction_deit import DeiTFeatureExtractor from ..vit.feature_extraction_vit import ViTFeatureExtractor else: + DeiTFeatureExtractor = None ViTFeatureExtractor = None # Build the list of all feature extractors FEATURE_EXTRACTOR_MAPPING = OrderedDict( [ + ("deit", DeiTFeatureExtractor), ("s2t", Speech2TextFeatureExtractor), ("vit", ViTFeatureExtractor), ("wav2vec2", Wav2Vec2FeatureExtractor), diff --git a/src/transformers/models/auto/modeling_auto.py 
b/src/transformers/models/auto/modeling_auto.py index cf01739296992e..f2770f4296485f 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -19,6 +19,8 @@ from collections import OrderedDict from ...utils import logging + +# Add modeling imports here from ..albert.modeling_albert import ( AlbertForMaskedLM, AlbertForMultipleChoice, @@ -95,6 +97,7 @@ DebertaV2ForTokenClassification, DebertaV2Model, ) +from ..deit.modeling_deit import DeiTForImageClassification, DeiTForImageClassificationWithTeacher, DeiTModel from ..distilbert.modeling_distilbert import ( DistilBertForMaskedLM, DistilBertForMultipleChoice, @@ -134,8 +137,6 @@ FunnelModel, ) from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model - -# Add modeling imports here from ..gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM, GPTNeoModel from ..ibert.modeling_ibert import ( IBertForMaskedLM, @@ -293,6 +294,7 @@ CTRLConfig, DebertaConfig, DebertaV2Config, + DeiTConfig, DistilBertConfig, DPRConfig, ElectraConfig, @@ -340,6 +342,7 @@ MODEL_MAPPING = OrderedDict( [ # Base model mapping + (DeiTConfig, DeiTModel), (GPTNeoConfig, GPTNeoModel), (BigBirdConfig, BigBirdModel), (Speech2TextConfig, Speech2TextModel), @@ -512,6 +515,7 @@ [ # Model for Image Classification mapping (ViTConfig, ViTForImageClassification), + (DeiTConfig, (DeiTForImageClassification, DeiTForImageClassificationWithTeacher)), ] ) diff --git a/src/transformers/models/deit/__init__.py b/src/transformers/models/deit/__init__.py new file mode 100644 index 00000000000000..255fb2626da37e --- /dev/null +++ b/src/transformers/models/deit/__init__.py @@ -0,0 +1,72 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"], +} + +if is_vision_available(): + _import_structure["feature_extraction_deit"] = ["DeiTFeatureExtractor"] + +if is_torch_available(): + _import_structure["modeling_deit"] = [ + "DEIT_PRETRAINED_MODEL_ARCHIVE_LIST", + "DeiTForImageClassification", + "DeiTForImageClassificationWithTeacher", + "DeiTModel", + "DeiTPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig + + if is_vision_available(): + from .feature_extraction_deit import DeiTFeatureExtractor + + if is_torch_available(): + from .modeling_deit import ( + DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, + DeiTForImageClassification, + DeiTForImageClassificationWithTeacher, + DeiTModel, + DeiTPreTrainedModel, + ) + + +else: + import importlib + import os + import sys + + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) diff --git a/src/transformers/models/deit/configuration_deit.py b/src/transformers/models/deit/configuration_deit.py new file mode 100644 index 00000000000000..0bbbff709b83f7 --- /dev/null +++ b/src/transformers/models/deit/configuration_deit.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2021 Facebook AI Research (FAIR) and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DeiT model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "facebook/deit-base-distilled-patch16-224": "https://huggingface.co/facebook/deit-base-patch16-224/resolve/main/config.json", + # See all DeiT models at https://huggingface.co/models?filter=deit +} + + +class DeiTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.DeiTModel`. It is used to + instantiate an DeiT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the DeiT + `facebook/deit-base-distilled-patch16-224 `__ + architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + hidden_size (:obj:`int`, `optional`, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. 
+ num_hidden_layers (:obj:`int`, `optional`, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, + :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): + The epsilon used by the layer normalization layers. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + image_size (:obj:`int`, `optional`, defaults to :obj:`224`): + The size (resolution) of each image. + patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): + The size (resolution) of each patch. + num_channels (:obj:`int`, `optional`, defaults to :obj:`3`): + The number of input channels. + + + Example:: + + >>> from transformers import DeiTModel, DeiTConfig + + >>> # Initializing a DeiT deit-base-distilled-patch16-224 style configuration + >>> configuration = DeiTConfig() + + >>> # Initializing a model from the deit-base-distilled-patch16-224 style configuration + >>> model = DeiTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + model_type = "deit" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + is_encoder_decoder=False, + image_size=224, + patch_size=16, + num_channels=3, + **kwargs + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels diff --git a/src/transformers/models/deit/convert_deit_timm_to_pytorch.py b/src/transformers/models/deit/convert_deit_timm_to_pytorch.py new file mode 100644 index 00000000000000..f866b90a80df09 --- /dev/null +++ b/src/transformers/models/deit/convert_deit_timm_to_pytorch.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DeiT distilled checkpoints from the timm library.""" + + +import argparse +from pathlib import Path + +import torch +from PIL import Image + +import requests +import timm +from transformers import DeiTConfig, DeiTFeatureExtractor, DeiTForImageClassificationWithTeacher +from transformers.utils import logging +from transformers.utils.imagenet_classes import id2label + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"deit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"deit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"deit.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"deit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"deit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"deit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"deit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"deit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"deit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"deit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("cls_token", "deit.embeddings.cls_token"), + ("dist_token", "deit.embeddings.distillation_token"), + ("patch_embed.proj.weight", "deit.embeddings.patch_embeddings.projection.weight"), + ("patch_embed.proj.bias", "deit.embeddings.patch_embeddings.projection.bias"), + ("pos_embed", "deit.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + pooler + rename_keys.extend( + [ + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ("pre_logits.fc.weight", "pooler.dense.weight"), + ("pre_logits.fc.bias", "pooler.dense.bias"), + ] + ) + + # if just the base model, we should remove "deit" from all keys that start with "deit" + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("deit") else pair for pair in rename_keys] + else: + # layernorm + classification heads + rename_keys.extend( + [ + ("norm.weight", "deit.layernorm.weight"), + ("norm.bias", "deit.layernorm.bias"), + ("head.weight", "cls_classifier.weight"), + ("head.bias", "cls_classifier.bias"), + ("head_dist.weight", "distillation_classifier.weight"), + ("head_dist.bias", "distillation_classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + 
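    # In timm, the query/key/value projections of each block are fused into a single "qkv" linear
    # layer: its weight has shape (3 * hidden_size, hidden_size) and its bias has shape
    # (3 * hidden_size,), ordered as [query; key; value]. The loop below slices this fused tensor
    # back into the three separate projections used by the HuggingFace attention implementation.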
for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "deit." + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_deit_checkpoint(deit_name, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our DeiT structure. + """ + + # define default DeiT configuration + config = DeiTConfig() + # all deit models have fine-tuned heads + base_model = False + # dataset (fine-tuned on ImageNet 2012), patch_size and image_size + config.num_labels = 1000 + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + config.patch_size = int(deit_name[-6:-4]) + config.image_size = int(deit_name[-3:]) + # size of the architecture + if deit_name[9:].startswith("tiny"): + config.hidden_size = 192 + config.intermediate_size = 768 + config.num_hidden_layers = 12 + config.num_attention_heads = 3 + elif deit_name[9:].startswith("small"): + config.hidden_size = 384 + config.intermediate_size = 1536 + config.num_hidden_layers = 12 + config.num_attention_heads = 6 + if deit_name[9:].startswith("base"): + pass + elif deit_name[4:].startswith("large"): + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + + # load original model from timm + timm_model = timm.create_model(deit_name, pretrained=True) + timm_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = timm_model.state_dict() + rename_keys = create_rename_keys(config, base_model) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model) + + # load HuggingFace model + model = DeiTForImageClassificationWithTeacher(config).eval() + model.load_state_dict(state_dict) + + # Check outputs on an image, prepared by DeiTFeatureExtractor + size = int( + (256 / 224) * config.image_size + ) # to maintain same ratio w.r.t. 
224 images, see https://github.com/facebookresearch/deit/blob/ab5715372db8c6cad5740714b2216d55aeae052e/datasets.py#L103 + feature_extractor = DeiTFeatureExtractor(size=size, crop_size=config.image_size) + encoding = feature_extractor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + + timm_logits = timm_model(pixel_values) + assert timm_logits.shape == outputs.logits.shape + assert torch.allclose(timm_logits, outputs.logits, atol=1e-3) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {deit_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving feature extractor to {pytorch_dump_folder_path}") + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--deit_name", + default="vit_deit_base_distilled_patch16_224", + type=str, + help="Name of the DeiT timm model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_deit_checkpoint(args.deit_name, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/deit/feature_extraction_deit.py b/src/transformers/models/deit/feature_extraction_deit.py new file mode 100644 index 00000000000000..aae149c40b3ee9 --- /dev/null +++ b/src/transformers/models/deit/feature_extraction_deit.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for DeiT.""" + +from typing import List, Optional, Union + +import numpy as np +from PIL import Image + +from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin +from ...file_utils import TensorType +from ...image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, ImageFeatureExtractionMixin, is_torch_tensor +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DeiTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): + r""" + Constructs a DeiT feature extractor. + + This feature extractor inherits from :class:`~transformers.FeatureExtractionMixin` which contains most of the main + methods. Users should refer to this superclass for more information regarding those methods. + + Args: + do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to resize the input to a certain :obj:`size`. + size (:obj:`int`, `optional`, defaults to 256): + Resize the input to the given size. Only has an effect if :obj:`do_resize` is set to :obj:`True`. + resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BICUBIC`): + An optional resampling filter. 
This can be one of :obj:`PIL.Image.NEAREST`, :obj:`PIL.Image.BOX`, + :obj:`PIL.Image.BILINEAR`, :obj:`PIL.Image.HAMMING`, :obj:`PIL.Image.BICUBIC` or :obj:`PIL.Image.LANCZOS`. + Only has an effect if :obj:`do_resize` is set to :obj:`True`. + do_center_crop (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to crop the input at the center. If the input size is smaller than :obj:`crop_size` along any edge, + the image is padded with 0's and then center cropped. + crop_size (:obj:`int`, `optional`, defaults to 224): + Desired output size when applying center-cropping. Only has an effect if :obj:`do_center_crop` is set to + :obj:`True`. + do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to normalize the input with :obj:`image_mean` and :obj:`image_std`. + image_mean (:obj:`List[int]`, defaults to :obj:`[0.485, 0.456, 0.406]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (:obj:`List[int]`, defaults to :obj:`[0.229, 0.224, 0.225]`): + The sequence of standard deviations for each channel, to be used when normalizing images. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize=True, + size=256, + resample=Image.BICUBIC, + do_center_crop=True, + crop_size=224, + do_normalize=True, + image_mean=None, + image_std=None, + **kwargs + ): + super().__init__(**kwargs) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + + def __call__( + self, + images: Union[ + Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa + ], + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> BatchFeature: + """ + Main method to prepare for the model one or several image(s). + + .. warning:: + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass + PIL images. + + Args: + images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + + return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`, defaults to :obj:`'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects. + * :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects. + + Returns: + :class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. 
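
        Example (an illustrative sketch with a random dummy image, using the default settings of resizing to 256,
        center-cropping to 224 and ImageNet normalization)::

            >>> import numpy as np
            >>> from PIL import Image
            >>> from transformers import DeiTFeatureExtractor

            >>> image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))
            >>> feature_extractor = DeiTFeatureExtractor()
            >>> inputs = feature_extractor(images=image, return_tensors="pt")
            >>> inputs["pixel_values"].shape
            torch.Size([1, 3, 224, 224])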
+ """ + # Input type checking for clearer error + valid_images = False + + # Check that images has a valid type + if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images): + valid_images = True + elif isinstance(images, (list, tuple)): + if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]): + valid_images = True + + if not valid_images: + raise ValueError( + "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example)," + "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)." + ) + + is_batched = bool( + isinstance(images, (list, tuple)) + and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0])) + ) + + if not is_batched: + images = [images] + + # transformations (resizing + center cropping + normalization) + if self.do_resize and self.size is not None and self.resample is not None: + images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] + if self.do_center_crop and self.crop_size is not None: + images = [self.center_crop(image, self.crop_size) for image in images] + if self.do_normalize: + images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] + + # return as BatchFeature + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py new file mode 100644 index 00000000000000..8844d7f656fab2 --- /dev/null +++ b/src/transformers/models/deit/modeling_deit.py @@ -0,0 +1,770 @@ +# coding=utf-8 +# Copyright 2021 Facebook AI Research (FAIR), Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeiT model. 
""" + + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import logging +from .configuration_deit import DeiTConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeiTConfig" + +DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/deit-base-distilled-patch16-224", + # See all DeiT models at https://huggingface.co/models?filter=deit +] + + +# Copied from transformers.models.vit.modeling_vit.to_2tuple +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + + +class DeiTEmbeddings(nn.Module): + """ + Construct the CLS token, distillation token, position and patch embeddings. + + """ + + def __init__(self, config): + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.patch_embeddings = PatchEmbeddings( + image_size=config.image_size, + patch_size=config.patch_size, + num_channels=config.num_channels, + embed_dim=config.hidden_size, + ) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values): + batch_size = pixel_values.shape[0] + embeddings = self.patch_embeddings(pixel_values) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + distillation_tokens = self.distillation_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1) + embeddings = embeddings + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.PatchEmbeddings +class PatchEmbeddings(nn.Module): + """ + Image to Patch Embedding. + + """ + + def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768): + super().__init__() + image_size = to_2tuple(image_size) + patch_size = to_2tuple(patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + # FIXME look at relaxing size constraints + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." 
+ ) + x = self.projection(pixel_values).flatten(2).transpose(1, 2) + return x + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DeiT +class DeiTSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, head_mask=None, output_attentions=False): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DeiT +class DeiTSelfOutput(nn.Module): + """ + The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
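Since the embedding modules above determine the transformer's sequence length, a small arithmetic sketch may help: with the default image_size=224 and patch_size=16 (the values assumed here), DeiT sees 196 patch tokens plus the [CLS] and distillation tokens.

# Sketch of the DeiT sequence length implied by the embeddings above (defaults assumed).
image_size, patch_size = 224, 16
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 patch tokens
seq_len = num_patches + 2                      # + [CLS] token + distillation token
print(num_patches, seq_len)                    # 196 198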
+ """ + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->DeiT +class DeiTAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = DeiTSelfAttention(config) + self.output = DeiTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, head_mask=None, output_attentions=False): + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DeiT +class DeiTIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DeiT +class DeiTOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT +class DeiTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = DeiTAttention(config) + self.intermediate = DeiTIntermediate(config) + self.output = DeiTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, head_mask=None, output_attentions=False): + 
self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in DeiT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in DeiT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + + # TODO feedforward chunking not working for now + # layer_output = apply_chunking_to_forward( + # self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layer_output + # ) + + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output) + return layer_output + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->DeiT +class DeiTEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + hidden_states, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + layer_head_mask, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel with ViT->DeiT all-casing +class DeiTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
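To make the pre-norm design noted in the comments of DeiTLayer above a bit more explicit, here is a condensed, illustrative sketch of one block; the callables are stand-ins for the real sub-modules rather than the patch's API.

# Condensed sketch of the pre-norm residual block implemented by DeiTLayer above (illustrative only).
def deit_block(x, attention, mlp, layernorm_before, layernorm_after):
    x = x + attention(layernorm_before(x))  # layernorm before self-attention, then first residual
    return x + mlp(layernorm_after(x))      # layernorm before the MLP, then second residual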
+ """ + + config_class = DeiTConfig + base_model_prefix = "deit" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +DEIT_START_DOCSTRING = r""" + This model is a PyTorch `torch.nn.Module `_ subclass. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config (:class:`~transformers.DeiTConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +DEIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + :class:`~transformers.DeiTFeatureExtractor`. See :meth:`transformers.DeiTFeatureExtractor.__call__` for + details. + + head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`): + Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeiT Model transformer outputting raw hidden-states without any specific head on top.", + DEIT_START_DOCSTRING, +) +class DeiTModel(DeiTPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = DeiTEmbeddings(config) + self.encoder = DeiTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = DeiTPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples:: + + >>> from transformers import DeiTFeatureExtractor, DeiTModel + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224') + >>> model = DeiTModel.from_pretrained('facebook/deit-base-distilled-patch16-224', add_pooling_layer=False) + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DeiT +class DeiTPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@add_start_docstrings( + """ + DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. 
+ """, + DEIT_START_DOCSTRING, +) +class DeiTForImageClassification(DeiTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.deit = DeiTModel(config, add_pooling_layer=False) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + self.init_weights() + + @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + head_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples:: + + >>> from transformers import DeiTFeatureExtractor, DeiTForImageClassification + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here, + >>> # so the head will be randomly initialized, hence the predictions will be random + >>> feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224') + >>> model = DeiTForImageClassification.from_pretrained('facebook/deit-base-distilled-patch16-224') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output[:, 0, :]) + # we don't use the distillation token + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +class DeiTForImageClassificationWithTeacherOutput(ModelOutput): + """ + Output type of :class:`~transformers.DeiTForImageClassificationWithTeacher`. + + Args: + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Prediction scores as the average of the cls_logits and distillation logits. 
+ cls_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the + class token). + distillation_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): + Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the + distillation token). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of + each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the + weighted average in the self-attention heads. + """ + + logits: torch.FloatTensor = None + cls_logits: torch.FloatTensor = None + distillation_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@add_start_docstrings( + """ + DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of + the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. + + .. warning:: + + This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet + supported. 
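The output class above states that :obj:`logits` is the average of the classification and distillation heads; the sketch below mirrors that computation with stand-in tensors (a batch of 1, hidden size 768 and 1000 labels are assumptions matching the base checkpoint, not values taken from this diff).

# Sketch of the inference-time head combination described above; random tensors stand in for real encoder output.
import torch
from torch import nn

sequence_output = torch.randn(1, 198, 768)  # 196 patch tokens + [CLS] + distillation token
cls_classifier = nn.Linear(768, 1000)
distillation_classifier = nn.Linear(768, 1000)

cls_logits = cls_classifier(sequence_output[:, 0, :])                    # head on the [CLS] token
distillation_logits = distillation_classifier(sequence_output[:, 1, :])  # head on the distillation token
logits = (cls_logits + distillation_logits) / 2                          # averaged at inference
print(logits.shape)  # torch.Size([1, 1000])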
+ """, + DEIT_START_DOCSTRING, +) +class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.deit = DeiTModel(config, add_pooling_layer=False) + + # Classifier heads + self.cls_classifier = ( + nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + self.distillation_classifier = ( + nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + self.init_weights() + + @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=DeiTForImageClassificationWithTeacherOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + Returns: + + Examples:: + + >>> from transformers import DeiTFeatureExtractor, DeiTForImageClassificationWithTeacher + >>> from PIL import Image + >>> import requests + + >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg' + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> feature_extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-base-distilled-patch16-224') + >>> model = DeiTForImageClassificationWithTeacher.from_pretrained('facebook/deit-base-distilled-patch16-224') + + >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = logits.argmax(-1).item() + >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + cls_logits = self.cls_classifier(sequence_output[:, 0, :]) + distillation_logits = self.distillation_classifier(sequence_output[:, 1, :]) + + # during inference, return the average of both classifier predictions + logits = (cls_logits + distillation_logits) / 2 + + if not return_dict: + output = (logits, cls_logits, distillation_logits) + outputs[2:] + return output + + return DeiTForImageClassificationWithTeacherOutput( + logits=logits, + cls_logits=cls_logits, + distillation_logits=distillation_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py index 06b5f13446841a..88d75f6e403cc5 100644 --- a/src/transformers/models/vit/convert_vit_timm_to_pytorch.py +++ b/src/transformers/models/vit/convert_vit_timm_to_pytorch.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. +# Copyright 2021 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Convert ViT checkpoints from the timm library.""" +"""Convert ViT and non-distilled DeiT checkpoints from the timm library.""" import argparse @@ -23,7 +23,7 @@ import requests import timm -from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel +from transformers import DeiTFeatureExtractor, ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel from transformers.utils import logging from transformers.utils.imagenet_classes import id2label @@ -151,23 +151,37 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): config.patch_size = int(vit_name[-6:-4]) config.image_size = int(vit_name[-3:]) # size of the architecture - if vit_name[4:].startswith("small"): - config.hidden_size = 768 - config.intermediate_size = 2304 - config.num_hidden_layers = 8 - config.num_attention_heads = 8 - if vit_name[4:].startswith("base"): - pass - elif vit_name[4:].startswith("large"): - config.hidden_size = 1024 - config.intermediate_size = 4096 - config.num_hidden_layers = 24 - config.num_attention_heads = 16 - elif vit_name[4:].startswith("huge"): - config.hidden_size = 1280 - config.intermediate_size = 5120 - config.num_hidden_layers = 32 - config.num_attention_heads = 16 + if "deit" in vit_name: + if vit_name[9:].startswith("tiny"): + config.hidden_size = 192 + config.intermediate_size = 768 + config.num_hidden_layers = 12 + config.num_attention_heads = 3 + elif vit_name[9:].startswith("small"): + config.hidden_size = 384 + config.intermediate_size = 1536 + config.num_hidden_layers = 12 + config.num_attention_heads = 6 + else: + pass + else: + if vit_name[4:].startswith("small"): + config.hidden_size = 768 + config.intermediate_size = 2304 + config.num_hidden_layers = 8 + config.num_attention_heads = 8 + elif vit_name[4:].startswith("base"): + pass + elif vit_name[4:].startswith("large"): + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + elif vit_name[4:].startswith("huge"): + config.hidden_size = 1280 + config.intermediate_size = 5120 + config.num_hidden_layers = 32 + config.num_attention_heads = 16 # load original model from timm timm_model = timm.create_model(vit_name, pretrained=True) @@ -189,8 +203,11 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): model = ViTForImageClassification(config).eval() model.load_state_dict(state_dict) - # Check outputs on an image, prepared by ViTFeatureExtractor - feature_extractor = ViTFeatureExtractor(size=config.image_size) + # Check outputs on an image, prepared by ViTFeatureExtractor/DeiTFeatureExtractor + if "deit" in vit_name: + feature_extractor = DeiTFeatureExtractor(size=config.image_size) + else: + feature_extractor = ViTFeatureExtractor(size=config.image_size) encoding = feature_extractor(images=prepare_img(), return_tensors="pt") pixel_values = encoding["pixel_values"] outputs = model(pixel_values) diff --git a/src/transformers/models/vit/feature_extraction_vit.py b/src/transformers/models/vit/feature_extraction_vit.py index c4cf52ebb95411..50e5d3ba3da1a8 100644 --- a/src/transformers/models/vit/feature_extraction_vit.py +++ b/src/transformers/models/vit/feature_extraction_vit.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright Google AI and The HuggingFace Inc. team. All rights reserved. +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -36,27 +36,41 @@ class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): methods. Users should refer to this superclass for more information regarding those methods. Args: - image_mean (:obj:`int`, defaults to :obj:`[0.5, 0.5, 0.5]`): - The sequence of means for each channel, to be used when normalizing images. - image_std (:obj:`int`, defaults to :obj:`[0.5, 0.5, 0.5]`): - The sequence of standard deviations for each channel, to be used when normalizing images. - do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`): - Whether or not to normalize the input with mean and standard deviation. do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether to resize the input to a certain :obj:`size`. size (:obj:`int`, `optional`, defaults to 224): Resize the input to the given size. Only has an effect if :obj:`do_resize` is set to :obj:`True`. + resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BILINEAR`): + An optional resampling filter. This can be one of :obj:`PIL.Image.NEAREST`, :obj:`PIL.Image.BOX`, + :obj:`PIL.Image.BILINEAR`, :obj:`PIL.Image.HAMMING`, :obj:`PIL.Image.BICUBIC` or :obj:`PIL.Image.LANCZOS`. + Only has an effect if :obj:`do_resize` is set to :obj:`True`. + do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to normalize the input with mean and standard deviation. + image_mean (:obj:`List[int]`, defaults to :obj:`[0.5, 0.5, 0.5]`): + The sequence of means for each channel, to be used when normalizing images. + image_std (:obj:`List[int]`, defaults to :obj:`[0.5, 0.5, 0.5]`): + The sequence of standard deviations for each channel, to be used when normalizing images. """ model_input_names = ["pixel_values"] - def __init__(self, image_mean=None, image_std=None, do_normalize=True, do_resize=True, size=224, **kwargs): + def __init__( + self, + do_resize=True, + size=224, + resample=Image.BILINEAR, + do_normalize=True, + image_mean=None, + image_std=None, + **kwargs + ): super().__init__(**kwargs) - self.image_mean = [0.5, 0.5, 0.5] - self.image_std = [0.5, 0.5, 0.5] - self.do_normalize = do_normalize self.do_resize = do_resize self.size = size + self.resample = resample + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5] + self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5] def __call__( self, @@ -80,12 +94,12 @@ def __call__( tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`): - If set, will return tensors instead of list of python integers. Acceptable values are: + return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`, defaults to :obj:`'np'`): + If set, will return tensors of a particular framework. Acceptable values are: * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. - * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.s + * :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects. * :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects. 
Returns: @@ -119,7 +133,7 @@ def __call__( # transformations (resizing + normalization) if self.do_resize and self.size is not None: - images = [self.resize(image=image, size=self.size) for image in images] + images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images] if self.do_normalize: images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images] diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index b7d20ec7859c28..559dfff83c3c33 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -175,7 +175,7 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): class ViTSelfOutput(nn.Module): """ - The residual connection is defined in VitLayer instead of here (as is the case with other models), due to the + The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the layernorm applied before each block. """ @@ -475,7 +475,7 @@ def forward( >>> image = Image.open(requests.get(url, stream=True).raw) >>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k') - >>> model = ViTModel.from_pretrained('google/vit-base-patch16-224') + >>> model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k') >>> inputs = feature_extractor(images=image, return_tensors="pt") >>> outputs = model(**inputs) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index ac8ee4d488c19d..2a24b845748a67 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1063,6 +1063,37 @@ def from_pretrained(self, *args, **kwargs): requires_backends(self, ["torch"]) +DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class DeiTForImageClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DeiTForImageClassificationWithTeacher: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DeiTModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DeiTPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 49d0f6f6c807d6..c4f55df8e8b5a3 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -7,6 +7,11 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class DeiTFeatureExtractor: + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class ViTFeatureExtractor: def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 53dbc9eeb91345..125755e06c4a16 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -20,14 +20,16 @@ class ConfigTester(object): - def __init__(self, parent, config_class=None, **kwargs): + def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs): 
self.parent = parent self.config_class = config_class + self.has_text_modality = has_text_modality self.inputs_dict = kwargs def create_and_test_config_common_properties(self): config = self.config_class(**self.inputs_dict) - self.parent.assertTrue(hasattr(config, "vocab_size")) + if self.has_text_modality: + self.parent.assertTrue(hasattr(config, "vocab_size")) self.parent.assertTrue(hasattr(config, "hidden_size")) self.parent.assertTrue(hasattr(config, "num_attention_heads")) self.parent.assertTrue(hasattr(config, "num_hidden_layers")) diff --git a/tests/test_feature_extraction_deit.py b/tests/test_feature_extraction_deit.py new file mode 100644 index 00000000000000..a2b60eafe6ef73 --- /dev/null +++ b/tests/test_feature_extraction_deit.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# Copyright 2021 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.file_utils import is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision + +from .test_feature_extraction_common import FeatureExtractionSavingTestMixin + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + from transformers import DeiTFeatureExtractor + + +class DeiTFeatureExtractionTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=20, + do_center_crop=True, + crop_size=18, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_resize = do_resize + self.size = size + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + + def prepare_feat_extract_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "do_center_crop": self.do_center_crop, + "crop_size": self.crop_size, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + } + + def prepare_inputs(self, equal_resolution=False, numpify=False, torchify=False): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. 
+ """ + + assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" + + if equal_resolution: + image_inputs = [] + for i in range(self.batch_size): + image_inputs.append( + np.random.randint( + 255, size=(self.num_channels, self.max_resolution, self.max_resolution), dtype=np.uint8 + ) + ) + else: + image_inputs = [] + for i in range(self.batch_size): + width, height = np.random.choice(np.arange(self.min_resolution, self.max_resolution), 2) + image_inputs.append(np.random.randint(255, size=(self.num_channels, width, height), dtype=np.uint8)) + + if not numpify and not torchify: + # PIL expects the channel dimension as last dimension + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + if torchify: + image_inputs = [torch.from_numpy(x) for x in image_inputs] + + return image_inputs + + +@require_torch +@require_vision +class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase): + + feature_extraction_class = DeiTFeatureExtractor if is_vision_available() else None + + def setUp(self): + self.feature_extract_tester = DeiTFeatureExtractionTester(self) + + @property + def feat_extract_dict(self): + return self.feature_extract_tester.prepare_feat_extract_dict() + + def test_feat_extract_properties(self): + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + self.assertTrue(hasattr(feature_extractor, "do_resize")) + self.assertTrue(hasattr(feature_extractor, "size")) + self.assertTrue(hasattr(feature_extractor, "do_center_crop")) + self.assertTrue(hasattr(feature_extractor, "center_crop")) + self.assertTrue(hasattr(feature_extractor, "do_normalize")) + self.assertTrue(hasattr(feature_extractor, "image_mean")) + self.assertTrue(hasattr(feature_extractor, "image_std")) + + def test_batch_feature(self): + pass + + def test_call_pil(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PIL images + image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + def test_call_numpy(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random numpy tensors + image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, 
return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + def test_call_pytorch(self): + # Initialize feature_extractor + feature_extractor = self.feature_extraction_class(**self.feat_extract_dict) + # create random PyTorch tensors + image_inputs = self.feature_extract_tester.prepare_inputs(equal_resolution=False, torchify=True) + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + 1, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) + + # Test batched + encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values + self.assertEqual( + encoded_images.shape, + ( + self.feature_extract_tester.batch_size, + self.feature_extract_tester.num_channels, + self.feature_extract_tester.crop_size, + self.feature_extract_tester.crop_size, + ), + ) diff --git a/tests/test_feature_extraction_vit.py b/tests/test_feature_extraction_vit.py index d80b51841d0fdd..5c8db9baa63bd9 100644 --- a/tests/test_feature_extraction_vit.py +++ b/tests/test_feature_extraction_vit.py @@ -42,11 +42,11 @@ def __init__( image_size=18, min_resolution=30, max_resolution=400, - image_mean=[0.5, 0.5, 0.5], - image_std=[0.5, 0.5, 0.5], - do_normalize=True, do_resize=True, size=18, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], ): self.parent = parent self.batch_size = batch_size @@ -54,11 +54,11 @@ def __init__( self.image_size = image_size self.min_resolution = min_resolution self.max_resolution = max_resolution - self.image_mean = image_mean - self.image_std = image_std - self.do_normalize = do_normalize self.do_resize = do_resize self.size = size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std def prepare_feat_extract_dict(self): return { diff --git a/tests/test_modeling_deit.py b/tests/test_modeling_deit.py new file mode 100644 index 00000000000000..d4d95f0b4910be --- /dev/null +++ b/tests/test_modeling_deit.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch DeiT model. 
""" + + +import inspect +import unittest + +from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.testing_utils import require_torch, require_vision, slow, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + from transformers import ( + MODEL_MAPPING, + DeiTConfig, + DeiTForImageClassification, + DeiTForImageClassificationWithTeacher, + DeiTModel, + ) + from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple + + +if is_vision_available(): + from PIL import Image + + from transformers import DeiTFeatureExtractor + + +class DeiTModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=30, + patch_size=2, + num_channels=3, + is_training=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + type_sequence_label_size=10, + initializer_range=0.02, + num_labels=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + + config = DeiTConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + is_decoder=False, + initializer_range=self.initializer_range, + ) + + return config, pixel_values, labels + + def create_and_check_model(self, config, pixel_values, labels): + model = DeiTModel(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + # expected sequence length = num_patches + 2 (we add 2 for the [CLS] and distillation tokens) + image_size = to_2tuple(self.image_size) + patch_size = to_2tuple(self.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 2, self.hidden_size)) + + def create_and_check_for_image_classification(self, config, pixel_values, labels): + config.num_labels = self.type_sequence_label_size + model = DeiTForImageClassification(config) + model.to(torch_device) + model.eval() + result = model(pixel_values, labels=labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, 
self.type_sequence_label_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + pixel_values, + labels, + ) = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class DeiTModelTest(ModelTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as DeiT does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = ( + ( + DeiTModel, + DeiTForImageClassification, + DeiTForImageClassificationWithTeacher, + ) + if is_torch_available() + else () + ) + + test_pruning = False + test_torchscript = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = DeiTModelTester(self) + self.config_tester = ConfigTester(self, config_class=DeiTConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_inputs_embeds(self): + # DeiT does not use inputs_embeds + pass + + def test_model_common_attributes(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + self.assertIsInstance(model.get_input_embeddings(), (torch.nn.Module)) + x = model.get_output_embeddings() + self.assertTrue(x is None or isinstance(x, torch.nn.Linear)) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + # in DeiT, the seq_len equals the number of patches + 2 (we add 2 for the [CLS] and distillation tokens) + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_len = num_patches + 2 + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + chunk_length = getattr(self.model_tester, "chunk_length", None) + if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): + encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + 
config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + if chunk_length is not None: + self.assertListEqual( + list(attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + if hasattr(self.model_tester, "num_hidden_states_types"): + added_hidden_states = self.model_tester.num_hidden_states_types + elif self.is_encoder_decoder: + added_hidden_states = 2 + else: + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + if chunk_length is not None: + self.assertListEqual( + list(self_attentions[0].shape[-4:]), + [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length], + ) + else: + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + # DeiT has a different seq_length + image_size = to_2tuple(self.model_tester.image_size) + patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + seq_length = num_patches + 2 + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [seq_length, self.model_tester.hidden_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + # special case for DeiTForImageClassificationWithTeacher model + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + + if return_labels: + if model_class.__name__ == 
"DeiTForImageClassificationWithTeacher": + del inputs_dict["labels"] + + return inputs_dict + + def test_training(self): + if not self.model_tester.is_training: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + for model_class in self.all_model_classes: + # DeiTForImageClassificationWithTeacher supports inference-only + if ( + model_class in MODEL_MAPPING.values() + or model_class.__name__ == "DeiTForImageClassificationWithTeacher" + ): + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + + def test_for_image_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_image_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in DEIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = DeiTModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +# We will verify our results on an image of cute cats +def prepare_img(): + image = Image.open("./tests/fixtures/tests_samples/COCO/cats.png") + return image + + +@require_vision +class DeiTModelIntegrationTest(unittest.TestCase): + @cached_property + def default_feature_extractor(self): + return ( + DeiTFeatureExtractor.from_pretrained("facebook/deit-base-distilled-patch16-224") + if is_vision_available() + else None + ) + + @slow + def test_inference_image_classification_head(self): + model = DeiTForImageClassificationWithTeacher.from_pretrained("facebook/deit-base-distilled-patch16-224").to( + torch_device + ) + + feature_extractor = self.default_feature_extractor + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + outputs = model(**inputs) + + # verify the logits + expected_shape = torch.Size((1, 1000)) + self.assertEqual(outputs.logits.shape, expected_shape) + + expected_slice = torch.tensor([-1.0266, 0.1912, -1.2861]).to(torch_device) + + self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) diff --git a/tests/test_modeling_vit.py b/tests/test_modeling_vit.py index ec060c9da68e13..b5436b7dc0e779 100644 --- a/tests/test_modeling_vit.py +++ b/tests/test_modeling_vit.py @@ -155,20 +155,10 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = ViTModelTester(self) - self.config_tester = ConfigTester(self, config_class=ViTConfig, hidden_size=37) + self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37) def test_config(self): - config = self.config_tester.config_class(**self.config_tester.inputs_dict) - # we omit vocab_size since ViT does not use this - self.config_tester.parent.assertTrue(hasattr(config, "hidden_size")) - self.config_tester.parent.assertTrue(hasattr(config, "num_attention_heads")) - self.config_tester.parent.assertTrue(hasattr(config, "num_hidden_layers")) - - self.config_tester.create_and_test_config_to_json_string() - self.config_tester.create_and_test_config_to_json_file() - self.config_tester.create_and_test_config_from_and_save_pretrained() - self.config_tester.create_and_test_config_with_num_labels() - self.config_tester.check_config_can_be_init_without_params() + self.config_tester.run_common_tests() def test_inputs_embeds(self): # ViT does not use inputs_embeds @@ -351,10 +341,7 @@ 
diff --git a/tests/test_modeling_vit.py b/tests/test_modeling_vit.py
index ec060c9da68e13..b5436b7dc0e779 100644
--- a/tests/test_modeling_vit.py
+++ b/tests/test_modeling_vit.py
@@ -155,20 +155,10 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
 
     def setUp(self):
         self.model_tester = ViTModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=ViTConfig, hidden_size=37)
+        self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37)
 
     def test_config(self):
-        config = self.config_tester.config_class(**self.config_tester.inputs_dict)
-        # we omit vocab_size since ViT does not use this
-        self.config_tester.parent.assertTrue(hasattr(config, "hidden_size"))
-        self.config_tester.parent.assertTrue(hasattr(config, "num_attention_heads"))
-        self.config_tester.parent.assertTrue(hasattr(config, "num_hidden_layers"))
-
-        self.config_tester.create_and_test_config_to_json_string()
-        self.config_tester.create_and_test_config_to_json_file()
-        self.config_tester.create_and_test_config_from_and_save_pretrained()
-        self.config_tester.create_and_test_config_with_num_labels()
-        self.config_tester.check_config_can_be_init_without_params()
+        self.config_tester.run_common_tests()
 
     def test_inputs_embeds(self):
         # ViT does not use inputs_embeds
@@ -351,10 +341,7 @@ def test_inference_image_classification_head(self):
         inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
 
         # forward pass
-        # currently failing
-        # see https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-double-but-got-scalar-type-float-for-argument-2-weight/38961/2
-        outputs = model(inputs["pixel_values"])
-        # outputs = model(**inputs)
+        outputs = model(**inputs)
 
         # verify the logits
         expected_shape = torch.Size((1, 1000))