diff --git a/README.md b/README.md index d94ef9d98cd888..da9de18606b9e2 100644 --- a/README.md +++ b/README.md @@ -491,6 +491,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[ViT Hybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VitDet](https://huggingface.co/docs/transformers/model_doc/vitdet)** (from Meta AI) released with the paper [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527) by Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He. 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. +1. **[ViTMatte](https://huggingface.co/docs/transformers/main/model_doc/vitmatte)** (from HUST-VL) rreleased with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. 1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son. 1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. diff --git a/README_es.md b/README_es.md index 6bcc369320fad3..04d88e9a375eaa 100644 --- a/README_es.md +++ b/README_es.md @@ -468,6 +468,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[ViT Hybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VitDet](https://huggingface.co/docs/transformers/model_doc/vitdet)** (from Meta AI) released with the paper [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527) by Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He. 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. +1. **[ViTMatte](https://huggingface.co/docs/transformers/main/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. 1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son. 1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. diff --git a/README_hd.md b/README_hd.md index d9e00af482581f..53fb0f7a32337d 100644 --- a/README_hd.md +++ b/README_hd.md @@ -440,6 +440,7 @@ conda install -c huggingface transformers 1. **[ViT Hybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VitDet](https://huggingface.co/docs/transformers/model_doc/vitdet)** (Meta AI से) Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He. द्वाराअनुसंधान पत्र [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527) के साथ जारी किया गया 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (मेटा एआई से) साथ में कागज [मास्कड ऑटोएन्कोडर स्केलेबल विजन लर्नर्स हैं](https://arxiv.org/ एब्स/2111.06377) कैमिंग हे, ज़िनेली चेन, सेनिंग ज़ी, यांगहो ली, पिओट्र डॉलर, रॉस गिर्शिक द्वारा। +1. **[ViTMatte](https://huggingface.co/docs/transformers/main/model_doc/vitmatte)** (HUST-VL से) Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. द्वाराअनुसंधान पत्र [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) के साथ जारी किया गया 1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (मेटा एआई से) साथ में कागज [लेबल-कुशल सीखने के लिए मास्क्ड स्याम देश के नेटवर्क](https://arxiv. org/abs/2204.07141) महमूद असरान, मथिल्डे कैरन, ईशान मिश्रा, पियोट्र बोजानोवस्की, फ्लोरियन बोर्डेस, पास्कल विंसेंट, आर्मंड जौलिन, माइकल रब्बत, निकोलस बल्लास द्वारा। 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (Kakao Enterprise से) Jaehyeon Kim, Jungil Kong, Juhee Son. द्वाराअनुसंधान पत्र [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) के साथ जारी किया गया 1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. diff --git a/README_ja.md b/README_ja.md index a7be7956a323df..57f2b83adaeedb 100644 --- a/README_ja.md +++ b/README_ja.md @@ -502,6 +502,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[ViT Hybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid)** (Google AI から) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby から公開された研究論文: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 1. **[VitDet](https://huggingface.co/docs/transformers/model_doc/vitdet)** (Meta AI から) Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He. から公開された研究論文 [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527) 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (Meta AI から) Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick から公開された研究論文: [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) +1. **[ViTMatte](https://huggingface.co/docs/transformers/main/model_doc/vitmatte)** (HUST-VL から) Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. から公開された研究論文 [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) 1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (Meta AI から) Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas から公開された研究論文: [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (Kakao Enterprise から) Jaehyeon Kim, Jungil Kong, Juhee Son. から公開された研究論文 [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) 1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. diff --git a/README_ko.md b/README_ko.md index f5b4dc5eb44a26..cf50289b405b80 100644 --- a/README_ko.md +++ b/README_ko.md @@ -417,6 +417,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[ViT Hybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid)** (Google AI 에서) Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 의 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 논문과 함께 발표했습니다. 1. **[VitDet](https://huggingface.co/docs/transformers/model_doc/vitdet)** (Meta AI 에서 제공)은 Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He.의 [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527)논문과 함께 발표했습니다. 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (Meta AI 에서) Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick 의 [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) 논문과 함께 발표했습니다. +1. **[ViTMatte](https://huggingface.co/docs/transformers/main/model_doc/vitmatte)** (HUST-VL 에서 제공)은 Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.의 [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272)논문과 함께 발표했습니다. 1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (Meta AI 에서) Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas 의 [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) 논문과 함께 발표했습니다. 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (Kakao Enterprise 에서 제공)은 Jaehyeon Kim, Jungil Kong, Juhee Son.의 [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103)논문과 함께 발표했습니다. 1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. diff --git a/README_zh-hans.md b/README_zh-hans.md index 56f929958aad21..af986fa7248788 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -441,6 +441,7 @@ conda install -c huggingface transformers 1. **[ViT Hybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid)** (来自 Google AI) 伴随论文 [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) 由 Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby 发布。 1. **[VitDet](https://huggingface.co/docs/transformers/model_doc/vitdet)** (来自 Meta AI) 伴随论文 [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527) 由 Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He 发布。 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (来自 Meta AI) 伴随论文 [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) 由 Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick 发布。 +1. **[ViTMatte](https://huggingface.co/docs/transformers/main/model_doc/vitmatte)** (来自 HUST-VL) 伴随论文 [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) 由 Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang 发布。 1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (来自 Meta AI) 伴随论文 [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas 发布. 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (来自 Kakao Enterprise) 伴随论文 [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) 由 Jaehyeon Kim, Jungil Kong, Juhee Son 发布。 1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (来自 Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) 由 Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. diff --git a/README_zh-hant.md b/README_zh-hant.md index 36d134b1f53d0b..26bd0cd91b888d 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -453,6 +453,7 @@ conda install -c huggingface transformers 1. **[ViT Hybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VitDet](https://huggingface.co/docs/transformers/model_doc/vitdet)** (from Meta AI) released with the paper [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527) by Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He. 1. **[ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. +1. **[ViTMatte](https://huggingface.co/docs/transformers/main/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. 1. **[ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son. 1. **[ViViT](https://huggingface.co/docs/transformers/model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 5c847c30a04041..83d87270aaa22f 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -573,6 +573,8 @@ title: ViTDet - local: model_doc/vit_mae title: ViTMAE + - local: model_doc/vitmatte + title: ViTMatte - local: model_doc/vit_msn title: ViTMSN - local: model_doc/vivit diff --git a/docs/source/en/index.md b/docs/source/en/index.md index d32112c5fed124..cb1ab70fd4f8b7 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -257,6 +257,7 @@ The documentation is organized into five sections: 1. **[ViT Hybrid](model_doc/vit_hybrid)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby. 1. **[VitDet](model_doc/vitdet)** (from Meta AI) released with the paper [Exploring Plain Vision Transformer Backbones for Object Detection](https://arxiv.org/abs/2203.16527) by Yanghao Li, Hanzi Mao, Ross Girshick, Kaiming He. 1. **[ViTMAE](model_doc/vit_mae)** (from Meta AI) released with the paper [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377) by Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, Ross Girshick. +1. **[ViTMatte](model_doc/vitmatte)** (from HUST-VL) rreleased with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. 1. **[ViTMSN](model_doc/vit_msn)** (from Meta AI) released with the paper [Masked Siamese Networks for Label-Efficient Learning](https://arxiv.org/abs/2204.07141) by Mahmoud Assran, Mathilde Caron, Ishan Misra, Piotr Bojanowski, Florian Bordes, Pascal Vincent, Armand Joulin, Michael Rabbat, Nicolas Ballas. 1. **[VITS](model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son. 1. **[ViViT](model_doc/vivit)** (from Google Research) released with the paper [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) by Anurag Arnab, Mostafa Dehghani, Georg Heigold, Chen Sun, Mario Lučić, Cordelia Schmid. @@ -480,6 +481,7 @@ Flax), PyTorch, and/or TensorFlow. | ViT Hybrid | ✅ | ❌ | ❌ | | VitDet | ✅ | ❌ | ❌ | | ViTMAE | ✅ | ✅ | ❌ | +| ViTMatte | ✅ | ❌ | ❌ | | ViTMSN | ✅ | ❌ | ❌ | | VITS | ✅ | ❌ | ❌ | | ViViT | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/vitmatte.md b/docs/source/en/model_doc/vitmatte.md new file mode 100644 index 00000000000000..d3e682686d8212 --- /dev/null +++ b/docs/source/en/model_doc/vitmatte.md @@ -0,0 +1,44 @@ + + +# VitMatte + +## Overview + +The VitMatte model was proposed in [Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang. +VitMatte leverages plain [Vision Transformers](vit) for the task of image matting, which is the process of accurately estimating the foreground object in images and videos. + +The abstract from the paper is the following: + +*Recently, plain vision Transformers (ViTs) have shown impressive performance on various computer vision tasks, thanks to their strong modeling capacity and large-scale pretraining. However, they have not yet conquered the problem of image matting. We hypothesize that image matting could also be boosted by ViTs and present a new efficient and robust ViT-based matting system, named ViTMatte. Our method utilizes (i) a hybrid attention mechanism combined with a convolution neck to help ViTs achieve an excellent performance-computation trade-off in matting tasks. (ii) Additionally, we introduce the detail capture module, which just consists of simple lightweight convolutions to complement the detailed information required by matting. To the best of our knowledge, ViTMatte is the first work to unleash the potential of ViT on image matting with concise adaptation. It inherits many superior properties from ViT to matting, including various pretraining strategies, concise architecture design, and flexible inference strategies. We evaluate ViTMatte on Composition-1k and Distinctions-646, the most commonly used benchmark for image matting, our method achieves state-of-the-art performance and outperforms prior matting works by a large margin.* + +Tips: + +- The model expects both the image and trimap (concatenated) as input. One can use [`ViTMatteImageProcessor`] for this purpose. + +This model was contributed by [nielsr](https://huggingface.co/nielsr). +The original code can be found [here](https://github.com/hustvl/ViTMatte). + + +## VitMatteConfig + +[[autodoc]] VitMatteConfig + +## VitMatteImageProcessor + +[[autodoc]] VitMatteImageProcessor + - preprocess + +## VitMatteForImageMatting + +[[autodoc]] VitMatteForImageMatting + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index aa26b914bd53a0..cd06fd001f290d 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -589,6 +589,7 @@ "models.vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"], "models.vit_msn": ["VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMSNConfig"], "models.vitdet": ["VITDET_PRETRAINED_CONFIG_ARCHIVE_MAP", "VitDetConfig"], + "models.vitmatte": ["VITMATTE_PRETRAINED_CONFIG_ARCHIVE_MAP", "VitMatteConfig"], "models.vits": [ "VITS_PRETRAINED_CONFIG_ARCHIVE_MAP", "VitsConfig", @@ -985,6 +986,7 @@ _import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"]) _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"]) _import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"]) + _import_structure["models.vitmatte"].append("VitMatteImageProcessor") _import_structure["models.vivit"].append("VivitImageProcessor") _import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"]) @@ -2957,6 +2959,13 @@ "VitDetPreTrainedModel", ] ) + _import_structure["models.vitmatte"].extend( + [ + "VITMATTE_PRETRAINED_MODEL_ARCHIVE_LIST", + "VitMatteForImageMatting", + "VitMattePreTrainedModel", + ] + ) _import_structure["models.vits"].extend( [ "VITS_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4675,6 +4684,7 @@ from .models.vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig from .models.vit_msn import VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMSNConfig from .models.vitdet import VITDET_PRETRAINED_CONFIG_ARCHIVE_MAP, VitDetConfig + from .models.vitmatte import VITMATTE_PRETRAINED_CONFIG_ARCHIVE_MAP, VitMatteConfig from .models.vits import ( VITS_PRETRAINED_CONFIG_ARCHIVE_MAP, VitsConfig, @@ -5029,6 +5039,7 @@ from .models.vilt import ViltFeatureExtractor, ViltImageProcessor, ViltProcessor from .models.vit import ViTFeatureExtractor, ViTImageProcessor from .models.vit_hybrid import ViTHybridImageProcessor + from .models.vitmatte import VitMatteImageProcessor from .models.vivit import VivitImageProcessor from .models.yolos import YolosFeatureExtractor, YolosImageProcessor @@ -6648,6 +6659,11 @@ VitDetModel, VitDetPreTrainedModel, ) + from .models.vitmatte import ( + VITMATTE_PRETRAINED_MODEL_ARCHIVE_LIST, + VitMatteForImageMatting, + VitMattePreTrainedModel, + ) from .models.vits import ( VITS_PRETRAINED_MODEL_ARCHIVE_LIST, VitsModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 1cc44c316b9869..a62e0fed1e2b27 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -213,6 +213,7 @@ vit_mae, vit_msn, vitdet, + vitmatte, vits, vivit, wav2vec2, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 027d22bd7a33f5..6f9663edd359ce 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -221,6 +221,7 @@ ("vit_mae", "ViTMAEConfig"), ("vit_msn", "ViTMSNConfig"), ("vitdet", "VitDetConfig"), + ("vitmatte", "VitMatteConfig"), ("vits", "VitsConfig"), ("vivit", "VivitConfig"), ("wav2vec2", "Wav2Vec2Config"), @@ -415,6 +416,7 @@ ("vit_mae", "VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vit_msn", "VIT_MSN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vitdet", "VITDET_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("vitmatte", "VITMATTE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vits", "VITS_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("vivit", "VIVIT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("wav2vec2", "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -651,6 +653,7 @@ ("vit_mae", "ViTMAE"), ("vit_msn", "ViTMSN"), ("vitdet", "VitDet"), + ("vitmatte", "ViTMatte"), ("vits", "VITS"), ("vivit", "ViViT"), ("wav2vec2", "Wav2Vec2"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index a09b29dce56e73..21817e58a3a8b9 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -107,6 +107,7 @@ ("vit_hybrid", "ViTHybridImageProcessor"), ("vit_mae", "ViTImageProcessor"), ("vit_msn", "ViTImageProcessor"), + ("vitmatte", "VitMatteImageProcessor"), ("xclip", "CLIPImageProcessor"), ("yolos", "YolosImageProcessor"), ] diff --git a/src/transformers/models/vitmatte/__init__.py b/src/transformers/models/vitmatte/__init__.py new file mode 100644 index 00000000000000..abbfae97c22030 --- /dev/null +++ b/src/transformers/models/vitmatte/__init__.py @@ -0,0 +1,72 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_vision_available, +) + + +_import_structure = {"configuration_vitmatte": ["VITMATTE_PRETRAINED_CONFIG_ARCHIVE_MAP", "VitMatteConfig"]} + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_vitmatte"] = ["VitMatteImageProcessor"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_vitmatte"] = [ + "VITMATTE_PRETRAINED_MODEL_ARCHIVE_LIST", + "VitMattePreTrainedModel", + "VitMatteForImageMatting", + ] + +if TYPE_CHECKING: + from .configuration_vitmatte import VITMATTE_PRETRAINED_CONFIG_ARCHIVE_MAP, VitMatteConfig + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_vitmatte import VitMatteImageProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_vitmatte import ( + VITMATTE_PRETRAINED_MODEL_ARCHIVE_LIST, + VitMatteForImageMatting, + VitMattePreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/vitmatte/configuration_vitmatte.py b/src/transformers/models/vitmatte/configuration_vitmatte.py new file mode 100644 index 00000000000000..cbbe30d9c9e0a2 --- /dev/null +++ b/src/transformers/models/vitmatte/configuration_vitmatte.py @@ -0,0 +1,107 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" VitMatte model configuration""" + +import copy +from typing import List + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto.configuration_auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + +VITMATTE_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "hustvl/vitmatte-small-composition-1k": "https://huggingface.co/hustvl/vitmatte-small-composition-1k/resolve/main/config.json", +} + + +class VitMatteConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of [`VitMatteForImageMatting`]. It is used to + instantiate a ViTMatte model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the ViTMatte + [hustvl/vitmatte-small-composition-1k](https://huggingface.co/hustvl/vitmatte-small-composition-1k) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitDetConfig()`): + The configuration of the backbone model. + hidden_size (`int`, *optional*, defaults to 384): + The number of input channels of the decoder. + batch_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the batch norm layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + convstream_hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 192]`): + The output channels of the ConvStream module. + fusion_hidden_sizes (`List[int]`, *optional*, defaults to `[256, 128, 64, 32]`): + The output channels of the Fusion blocks. + + Example: + + ```python + >>> from transformers import VitMatteConfig, VitMatteForImageMatting + + >>> # Initializing a ViTMatte hustvl/vitmatte-small-composition-1k style configuration + >>> configuration = VitMatteConfig() + + >>> # Initializing a model (with random weights) from the hustvl/vitmatte-small-composition-1k style configuration + >>> model = VitMatteForImageMatting(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "vitmatte" + + def __init__( + self, + backbone_config: PretrainedConfig = None, + hidden_size: int = 384, + batch_norm_eps: float = 1e-5, + initializer_range: float = 0.02, + convstream_hidden_sizes: List[int] = [48, 96, 192], + fusion_hidden_sizes: List[int] = [256, 128, 64, 32], + **kwargs, + ): + super().__init__(**kwargs) + + if backbone_config is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `VitDet` backbone.") + backbone_config = CONFIG_MAPPING["vitdet"](out_features=["stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + self.backbone_config = backbone_config + self.batch_norm_eps = batch_norm_eps + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.convstream_hidden_sizes = convstream_hidden_sizes + self.fusion_hidden_sizes = fusion_hidden_sizes + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["backbone_config"] = self.backbone_config.to_dict() + output["model_type"] = self.__class__.model_type + return output diff --git a/src/transformers/models/vitmatte/convert_vitmatte_to_hf.py b/src/transformers/models/vitmatte/convert_vitmatte_to_hf.py new file mode 100644 index 00000000000000..bcc05563337198 --- /dev/null +++ b/src/transformers/models/vitmatte/convert_vitmatte_to_hf.py @@ -0,0 +1,170 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert VitMatte checkpoints from the original repository. + +URL: https://github.com/hustvl/ViTMatte +""" + +import argparse + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import VitDetConfig, VitMatteConfig, VitMatteForImageMatting, VitMatteImageProcessor + + +def get_config(model_name): + hidden_size = 384 if "small" in model_name else 768 + num_attention_heads = 6 if "small" in model_name else 12 + + backbone_config = VitDetConfig( + num_channels=4, + image_size=512, + pretrain_image_size=224, + patch_size=16, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_absolute_position_embeddings=True, + use_relative_position_embeddings=True, + window_size=14, + # 2, 5, 8, 11 for global attention + window_block_indices=[0, 1, 3, 4, 6, 7, 9, 10], + residual_block_indices=[2, 5, 8, 11], + out_features=["stage12"], + ) + + return VitMatteConfig(backbone_config=backbone_config, hidden_size=hidden_size) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + # stem + rename_keys.append(("backbone.pos_embed", "backbone.embeddings.position_embeddings")) + rename_keys.append(("backbone.patch_embed.proj.weight", "backbone.embeddings.projection.weight")) + rename_keys.append(("backbone.patch_embed.proj.bias", "backbone.embeddings.projection.bias")) + # fmt: on + + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +def convert_vitmatte_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): + config = get_config(model_name) + + # load original state dict + model_name_to_filename = { + "vitmatte-small-composition-1k": "ViTMatte_S_Com.pth", + "vitmatte-base-composition-1k": "ViTMatte_B_Com.pth", + "vitmatte-small-distinctions-646": "ViTMatte_S_DIS.pth", + "vitmatte-base-distinctions-646": "ViTMatte_B_DIS.pth", + } + + filename = model_name_to_filename[model_name] + filepath = hf_hub_download(repo_id="nielsr/vitmatte-checkpoints", filename=filename, repo_type="model") + state_dict = torch.load(filepath, map_location="cpu") + + # rename keys + for key in state_dict.copy().keys(): + val = state_dict.pop(key) + if "backbone.blocks" in key: + key = key.replace("backbone.blocks", "backbone.encoder.layer") + if "attn" in key: + key = key.replace("attn", "attention") + if "fusion_blks" in key: + key = key.replace("fusion_blks", "fusion_blocks") + if "bn" in key: + key = key.replace("bn", "batch_norm") + state_dict[key] = val + + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + + # create model + processor = VitMatteImageProcessor() + model = VitMatteForImageMatting(config) + model.eval() + + # load state dict + model.load_state_dict(state_dict) + + # verify on dummy image + trimap + url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_rgb.png?raw=true" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_trimap.png?raw=true" + trimap = Image.open(requests.get(url, stream=True).raw) + + pixel_values = processor(images=image, trimaps=trimap.convert("L"), return_tensors="pt").pixel_values + + with torch.no_grad(): + alphas = model(pixel_values).alphas + + if model_name == "vitmatte-small-composition-1k": + expected_slice = torch.tensor([[0.9977, 0.9987, 0.9990], [0.9980, 0.9998, 0.9998], [0.9983, 0.9998, 0.9998]]) + elif model_name == "vitmatte-base-composition-1k": + expected_slice = torch.tensor([[0.9972, 0.9971, 0.9981], [0.9948, 0.9987, 0.9994], [0.9963, 0.9992, 0.9995]]) + elif model_name == "vitmatte-small-distinctions-646": + expected_slice = torch.tensor([[0.9880, 0.9970, 0.9972], [0.9960, 0.9996, 0.9997], [0.9963, 0.9996, 0.9997]]) + elif model_name == "vitmatte-base-distinctions-646": + expected_slice = torch.tensor([[0.9963, 0.9998, 0.9999], [0.9995, 1.0000, 1.0000], [0.9992, 0.9999, 1.0000]]) + + assert torch.allclose(alphas[0, 0, :3, :3], expected_slice, atol=1e-4) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + print(f"Saving model and processor of {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print(f"Pushing model and processor for {model_name} to hub") + model.push_to_hub(f"hustvl/{model_name}") + processor.push_to_hub(f"hustvl/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="vitmatte-small-composition-1k", + type=str, + choices=[ + "vitmatte-small-composition-1k", + "vitmatte-base-composition-1k", + "vitmatte-small-distinctions-646", + "vitmatte-base-distinctions-646", + ], + help="Name of the VitMatte model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_vitmatte_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte.py b/src/transformers/models/vitmatte/image_processing_vitmatte.py new file mode 100644 index 00000000000000..a0bd940b80b133 --- /dev/null +++ b/src/transformers/models/vitmatte/image_processing_vitmatte.py @@ -0,0 +1,267 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for ViTMatte.""" + +from typing import List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import pad, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class VitMatteImageProcessor(BaseImageProcessor): + r""" + Constructs a ViTMatte image processor. + + Args: + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to 1/255): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to make the width and height divisible by `size_divisibility`. Can be overridden + by the `do_pad` parameter in the `preprocess` method. + size_divisibility (`int`, *optional*, defaults to 32): + The width and height of the image will be padded to be divisible by this number. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: bool = True, + size_divisibility: int = 32, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_rescale = do_rescale + self.do_normalize = do_normalize + self.do_pad = do_pad + self.rescale_factor = rescale_factor + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.size_divisibility = size_divisibility + + def pad_image( + self, + image: np.ndarray, + size_divisibility: int = 32, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Args: + image (`np.ndarray`): + Image to pad. + size_divisibility (`int`, *optional*, defaults to 32): + The width and height of the image will be padded to be divisible by this number. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + height, width = get_image_size(image, input_data_format) + + if height % size_divisibility != 0 or width % size_divisibility != 0: + pad_height = size_divisibility - height % size_divisibility + pad_width = size_divisibility - width % size_divisibility + padding = ((0, pad_height), (0, pad_width)) + image = pad(image, padding=padding, data_format=data_format, input_data_format=input_data_format) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_data_format) + + return image + + def preprocess( + self, + images: ImageInput, + trimaps: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisibility: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + trimaps (`ImageInput`): + Trimap to preprocess. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. + size_divisibility (`int`, *optional*, defaults to `self.size_divisibility`): + The size divisibility to pad the image to if `do_pad` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_pad = do_pad if do_pad is not None else self.do_pad + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + size_divisibility = size_divisibility if size_divisibility is not None else self.size_divisibility + + images = make_list_of_images(images) + trimaps = make_list_of_images(trimaps, expected_ndims=2) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + if not valid_images(trimaps): + raise ValueError( + "Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_pad and size_divisibility is None: + raise ValueError("Size divisilibyt must be specified if do_pad is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("Image mean and std must be specified if do_normalize is True.") + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + trimaps = [to_numpy_array(trimap) for trimap in trimaps] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + trimaps = [ + self.rescale(image=trimap, scale=rescale_factor, input_data_format=input_data_format) + for trimap in trimaps + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + # concatenate images and trimaps + images = [ + np.concatenate([image, np.expand_dims(trimap, axis=-1)], axis=-1) for image, trimap in zip(images, trimaps) + ] + + if do_pad: + images = [ + self.pad_image(image, size_divisibility=size_divisibility, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image=image, channel_dim=data_format, input_channel_dim=input_data_format) + for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py new file mode 100644 index 00000000000000..b23bdd21d56b85 --- /dev/null +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -0,0 +1,343 @@ +# coding=utf-8 +# Copyright 2023 HUST-VL and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch ViTMatte model.""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + +from ... import AutoBackbone +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_vitmatte import VitMatteConfig + + +VITMATTE_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "hustvl/vitmatte-small-composition-1k", + # See all VitMatte models at https://huggingface.co/models?filter=vitmatte +] + + +# General docstring +_CONFIG_FOR_DOC = "VitMatteConfig" + + +@dataclass +class ImageMattingOutput(ModelOutput): + """ + Class for outputs of image matting models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Loss. + alphas (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Estimated alpha values. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + alphas: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class VitMattePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = VitMatteConfig + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BackboneMixin): + module.gradient_checkpointing = value + + +class VitMatteBasicConv3x3(nn.Module): + """ + Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers. + """ + + def __init__(self, config, in_channels, out_channels, stride=2, padding=1): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=False, + ) + self.batch_norm = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps) + self.relu = nn.ReLU() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.batch_norm(hidden_state) + hidden_state = self.relu(hidden_state) + + return hidden_state + + +class VitMatteConvStream(nn.Module): + """ + Simple ConvStream containing a series of basic conv3x3 layers to extract detail features. + """ + + def __init__(self, config): + super().__init__() + + in_channels = config.backbone_config.num_channels + out_channels = config.convstream_hidden_sizes + + self.convs = nn.ModuleList() + self.conv_chans = [in_channels] + out_channels + + for i in range(len(self.conv_chans) - 1): + in_chan_ = self.conv_chans[i] + out_chan_ = self.conv_chans[i + 1] + self.convs.append(VitMatteBasicConv3x3(config, in_chan_, out_chan_)) + + def forward(self, pixel_values): + out_dict = {"detailed_feature_map_0": pixel_values} + embeddings = pixel_values + for i in range(len(self.convs)): + embeddings = self.convs[i](embeddings) + name_ = "detailed_feature_map_" + str(i + 1) + out_dict[name_] = embeddings + + return out_dict + + +class VitMatteFusionBlock(nn.Module): + """ + Simple fusion block to fuse features from ConvStream and Plain Vision Transformer. + """ + + def __init__(self, config, in_channels, out_channels): + super().__init__() + self.conv = VitMatteBasicConv3x3(config, in_channels, out_channels, stride=1, padding=1) + + def forward(self, features, detailed_feature_map): + upscaled_features = nn.functional.interpolate(features, scale_factor=2, mode="bilinear", align_corners=False) + out = torch.cat([detailed_feature_map, upscaled_features], dim=1) + out = self.conv(out) + + return out + + +class VitMatteHead(nn.Module): + """ + Simple Matting Head, containing only conv3x3 and conv1x1 layers. + """ + + def __init__(self, config): + super().__init__() + + in_channels = config.fusion_hidden_sizes[-1] + mid_channels = 16 + + self.matting_convs = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(mid_channels), + nn.ReLU(True), + nn.Conv2d(mid_channels, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, hidden_state): + hidden_state = self.matting_convs(hidden_state) + + return hidden_state + + +class VitMatteDetailCaptureModule(nn.Module): + """ + Simple and lightweight Detail Capture Module for ViT Matting. + """ + + def __init__(self, config): + super().__init__() + if len(config.fusion_hidden_sizes) != len(config.convstream_hidden_sizes) + 1: + raise ValueError( + "The length of fusion_hidden_sizes should be equal to the length of convstream_hidden_sizes + 1." + ) + + self.config = config + self.convstream = VitMatteConvStream(config) + self.conv_chans = self.convstream.conv_chans + + self.fusion_blocks = nn.ModuleList() + self.fusion_channels = [config.hidden_size] + config.fusion_hidden_sizes + + for i in range(len(self.fusion_channels) - 1): + self.fusion_blocks.append( + VitMatteFusionBlock( + config=config, + in_channels=self.fusion_channels[i] + self.conv_chans[-(i + 1)], + out_channels=self.fusion_channels[i + 1], + ) + ) + + self.matting_head = VitMatteHead(config) + + def forward(self, features, pixel_values): + detail_features = self.convstream(pixel_values) + for i in range(len(self.fusion_blocks)): + detailed_feature_map_name = "detailed_feature_map_" + str(len(self.fusion_blocks) - i - 1) + features = self.fusion_blocks[i](features, detail_features[detailed_feature_map_name]) + + alphas = torch.sigmoid(self.matting_head(features)) + + return alphas + + +VITMATTE_START_DOCSTRING = r""" + Parameters: + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + config ([`UperNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VITMATTE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`VitMatteImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers in case the backbone has them. See + `attentions` under returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers of the backbone. See `hidden_states` under + returned tensors for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ViTMatte framework leveraging any vision backbone e.g. for ADE20k, CityScapes.""", + VITMATTE_START_DOCSTRING, +) +class VitMatteForImageMatting(VitMattePreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.backbone = AutoBackbone.from_config(config.backbone_config) + self.decoder = VitMatteDetailCaptureModule(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VITMATTE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=ImageMattingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ): + """ + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth image matting for computing the loss. + + Returns: + + Examples: + + ```python + >>> from transformers import VitMatteImageProcessor, VitMatteForImageMatting + >>> import torch + >>> from PIL import Image + >>> from huggingface_hub import hf_hub_download + + >>> processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k") + >>> model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k") + + >>> filepath = hf_hub_download( + ... repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset" + ... ) + >>> image = Image.open(filepath).convert("RGB") + >>> filepath = hf_hub_download( + ... repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset" + ... ) + >>> trimap = Image.open(filepath).convert("L") + + >>> # prepare image + trimap for the model + >>> inputs = processor(images=image, trimaps=trimap, return_tensors="pt") + + >>> with torch.no_grad(): + ... alphas = model(**inputs).alphas + >>> print(alphas.shape) + torch.Size([1, 1, 640, 960]) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + + features = outputs.feature_maps[-1] + alphas = self.decoder(features, pixel_values) + + loss = None + if labels is not None: + raise NotImplementedError("Training is not yet supported") + + if not return_dict: + output = (alphas,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageMattingOutput( + loss=loss, + alphas=alphas, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index a564bf7a55fef3..215ba5647b200b 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8002,6 +8002,23 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +VITMATTE_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class VitMatteForImageMatting(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class VitMattePreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + VITS_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 134a853eec3f37..c1a1d1d8542d9d 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -513,6 +513,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class VitMatteImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class VivitImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/vitmatte/__init__.py b/tests/models/vitmatte/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/vitmatte/test_image_processing_vitmatte.py b/tests/models/vitmatte/test_image_processing_vitmatte.py new file mode 100644 index 00000000000000..e1009c75928320 --- /dev/null +++ b/tests/models/vitmatte/test_image_processing_vitmatte.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + + +if is_vision_available(): + from PIL import Image + + from transformers import VitMatteImageProcessor + + +class VitMatteImageProcessingTester(unittest.TestCase): + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_rescale=True, + rescale_factor=0.5, + do_pad=True, + size_divisibility=10, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_pad = do_pad + self.size_divisibility = size_divisibility + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + + def prepare_image_processor_dict(self): + return { + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_normalize": self.do_normalize, + "do_rescale": self.do_rescale, + "rescale_factor": self.rescale_factor, + "do_pad": self.do_pad, + "size_divisibility": self.size_divisibility, + } + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + image_processing_class = VitMatteImageProcessor if is_vision_available() else None + + def setUp(self): + self.image_processor_tester = VitMatteImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_pad")) + self.assertTrue(hasattr(image_processing, "size_divisibility")) + + def test_call_numpy(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input (image processor does not support batched inputs) + image = image_inputs[0] + trimap = np.random.randint(0, 3, size=image.shape[:2]) + encoded_images = image_processing(images=image, trimaps=trimap, return_tensors="pt").pixel_values + + # Verify that width and height can be divided by size_divisibility + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + + def test_call_pytorch(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input (image processor does not support batched inputs) + image = image_inputs[0] + trimap = np.random.randint(0, 3, size=image.shape[:2]) + encoded_images = image_processing(images=image, trimaps=trimap, return_tensors="pt").pixel_values + + # Verify that width and height can be divided by size_divisibility + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + + def test_call_pil(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input (image processor does not support batched inputs) + image = image_inputs[0] + trimap = np.random.randint(0, 3, size=image.size[::-1]) + encoded_images = image_processing(images=image, trimaps=trimap, return_tensors="pt").pixel_values + + # Verify that width and height can be divided by size_divisibility + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + + def test_call_numpy_4_channels(self): + # Test that can process images which have an arbitrary number of channels + # Initialize image_processing + image_processor = self.image_processing_class(**self.image_processor_dict) + + # create random numpy tensors + self.image_processor_tester.num_channels = 4 + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + + # Test not batched input (image processor does not support batched inputs) + image = image_inputs[0] + trimap = np.random.randint(0, 3, size=image.shape[:2]) + encoded_images = image_processor( + images=image, + trimaps=trimap, + input_data_format="channels_first", + image_mean=0, + image_std=1, + return_tensors="pt", + ).pixel_values + + # Verify that width and height can be divided by size_divisibility + self.assertTrue(encoded_images.shape[-1] % self.image_processor_tester.size_divisibility == 0) + self.assertTrue(encoded_images.shape[-2] % self.image_processor_tester.size_divisibility == 0) + + def test_padding(self): + image_processing = self.image_processing_class(**self.image_processor_dict) + image = np.random.randn(3, 249, 491) + images = image_processing.pad_image(image) + assert images.shape == (3, 256, 512) diff --git a/tests/models/vitmatte/test_modeling_vitmatte.py b/tests/models/vitmatte/test_modeling_vitmatte.py new file mode 100644 index 00000000000000..09e3f60966b0ad --- /dev/null +++ b/tests/models/vitmatte/test_modeling_vitmatte.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Testing suite for the PyTorch VitMatte model. """ + + +import inspect +import unittest + +from huggingface_hub import hf_hub_download + +from transformers import VitMatteConfig +from transformers.testing_utils import ( + require_torch, + slow, + torch_device, +) +from transformers.utils import is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import VitDetConfig, VitMatteForImageMatting + from transformers.models.vitmatte.modeling_vitmatte import VITMATTE_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + from transformers import VitMatteImageProcessor + + +class VitMatteModelTester: + def __init__( + self, + parent, + batch_size=13, + image_size=32, + patch_size=16, + num_channels=4, + is_training=True, + use_labels=False, + hidden_size=2, + num_hidden_layers=2, + num_attention_heads=2, + hidden_act="gelu", + type_sequence_label_size=10, + initializer_range=0.02, + scope=None, + out_features=["stage1"], + fusion_hidden_sizes=[128, 64, 32, 16], + ): + self.parent = parent + self.batch_size = batch_size + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.is_training = is_training + self.use_labels = use_labels + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.scope = scope + self.out_features = out_features + self.fusion_hidden_sizes = fusion_hidden_sizes + + self.seq_length = (self.image_size // self.patch_size) ** 2 + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + + labels = None + if self.use_labels: + raise NotImplementedError("Training is not yet supported") + + config = self.get_config() + + return config, pixel_values, labels + + def get_backbone_config(self): + return VitDetConfig( + image_size=self.image_size, + patch_size=self.patch_size, + num_channels=self.num_channels, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + hidden_size=self.hidden_size, + is_training=self.is_training, + hidden_act=self.hidden_act, + out_features=self.out_features, + ) + + def get_config(self): + return VitMatteConfig( + backbone_config=self.get_backbone_config(), + hidden_size=self.hidden_size, + fusion_hidden_sizes=self.fusion_hidden_sizes, + ) + + def create_and_check_model(self, config, pixel_values, labels): + model = VitMatteForImageMatting(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + self.parent.assertEqual(result.alphas.shape, (self.batch_size, 1, self.image_size, self.image_size)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values, labels = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class VitMatteModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + """ + Here we also overwrite some of the tests of test_modeling_common.py, as VitMatte does not use input_ids, inputs_embeds, + attention_mask and seq_length. + """ + + all_model_classes = (VitMatteForImageMatting,) if is_torch_available() else () + pipeline_model_mapping = {} + + fx_compatible = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + + def setUp(self): + self.model_tester = VitMatteModelTester(self) + self.config_tester = ConfigTester(self, config_class=VitMatteConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.create_and_test_config_common_properties() + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def create_and_test_config_common_properties(self): + return + + @unittest.skip(reason="VitMatte does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Training is not yet supported") + def test_training(self): + pass + + @unittest.skip(reason="Training is not yet supported") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="ViTMatte does not support input and output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in VITMATTE_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = VitMatteForImageMatting.from_pretrained(model_name) + self.assertIsNotNone(model) + + @unittest.skip(reason="ViTMatte does not support retaining gradient on attention logits") + def test_retain_grad_hidden_states_attentions(self): + pass + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + + expected_num_layers = getattr( + self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1 + ) + self.assertEqual(len(hidden_states), expected_num_layers) + + self.assertListEqual( + list(hidden_states[0].shape[-2:]), + [2, 2], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + print("Hello we're here") + + check_hidden_states_output(inputs_dict, config, model_class) + + +@require_torch +class VitMatteModelIntegrationTest(unittest.TestCase): + @slow + def test_inference(self): + processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k") + model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k").to(torch_device) + + filepath = hf_hub_download( + repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset" + ) + image = Image.open(filepath).convert("RGB") + filepath = hf_hub_download( + repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset" + ) + trimap = Image.open(filepath).convert("L") + + # prepare image + trimap for the model + inputs = processor(images=image, trimaps=trimap, return_tensors="pt").to(torch_device) + + with torch.no_grad(): + alphas = model(**inputs).alphas + + expected_shape = torch.Size((1, 1, 640, 960)) + self.assertEqual(alphas.shape, expected_shape) + + expected_slice = torch.tensor( + [[0.9977, 0.9987, 0.9990], [0.9980, 0.9998, 0.9998], [0.9983, 0.9998, 0.9998]], device=torch_device + ) + self.assertTrue(torch.allclose(alphas[0, 0, :3, :3], expected_slice, atol=1e-4)) diff --git a/utils/check_repo.py b/utils/check_repo.py index f5d3f065d57b19..c8bd228eaa776e 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -277,6 +277,7 @@ "SpeechT5ForSpeechToSpeech", "SpeechT5ForTextToSpeech", "SpeechT5HifiGan", + "VitMatteForImageMatting", ] # DO NOT edit this list!