Add Vision Transformer + ViTFeatureExtractor #10513

Closed

Commits (45 total; changes shown from 38 commits)
ffa4bf8  Fix rebase with master (NielsRogge, Mar 15, 2021)
a8d48c2  Add List typing hint (NielsRogge, Mar 3, 2021)
8363469  Remove annotations (NielsRogge, Mar 3, 2021)
f4e4fb3  Potential bug fix (NielsRogge, Mar 3, 2021)
ff97a92  Bug fix (NielsRogge, Mar 3, 2021)
46ea2b6  Rename inputs to pixel_values (NielsRogge, Mar 3, 2021)
b6fba1c  First draft of ImageProcessor tests (NielsRogge, Mar 3, 2021)
bc6f12d  Clean up: remove print statements, remove unused variables (NielsRogge, Mar 4, 2021)
56ccfa8  Remove load_tf_weights_in_vit (NielsRogge, Mar 4, 2021)
dc3c23f  Rename pixel_mask to attention_mask (NielsRogge, Mar 4, 2021)
d3607b4  Improve tests (NielsRogge, Mar 4, 2021)
dca36be  Small cleanup (NielsRogge, Mar 4, 2021)
9f40352  Remove is_decoder logic and make style (NielsRogge, Mar 4, 2021)
6da3261  Fix another rebase issue (NielsRogge, Mar 15, 2021)
d48609a  Fix another rebase issue (NielsRogge, Mar 15, 2021)
43524d0  Major cleanup - renamed ViTImageProcessor to ViTFeatureExtractor (NielsRogge, Mar 16, 2021)
b168ee4  Add torch.stack (NielsRogge, Mar 16, 2021)
c39f155  Add documentation (NielsRogge, Mar 16, 2021)
20e3d1e  Remove test_image_processor_common (NielsRogge, Mar 16, 2021)
b2b3432  Improve model tests (NielsRogge, Mar 16, 2021)
43ba11f  Add is_torchvision_available to general init of vit (NielsRogge, Mar 16, 2021)
4fb8def  Fix import of ViTFeatureExtractor (NielsRogge, Mar 16, 2021)
4c91fb3  Fix another bug with init (NielsRogge, Mar 16, 2021)
c3dfbe6  Use append instead of extend (NielsRogge, Mar 16, 2021)
5cd7dfd  Make all tests of ViTFeatureExtractor pass (NielsRogge, Mar 17, 2021)
9637b85  Improve model tests (NielsRogge, Mar 17, 2021)
872ae16  24 model tests pass, 6 fail on cpu (NielsRogge, Mar 17, 2021)
a7a9e0e  Minor fixes (NielsRogge, Mar 17, 2021)
466cef1  Improve tests (NielsRogge, Mar 19, 2021)
647f0e4  All tests are passing (NielsRogge, Mar 19, 2021)
e01294c  Make style & quality, docs improvements (NielsRogge, Mar 19, 2021)
0e02f64  Remove attention mask, add support for head mask (NielsRogge, Mar 20, 2021)
02c06bc  Merge branch 'master' into modeling_vit_pytorch_v2 (NielsRogge, Mar 20, 2021)
f5ba2f4  Some docs improvements + clearer input checking for ViTFeatureExtractor (NielsRogge, Mar 21, 2021)
852b777  Change normalization to match original implementation (NielsRogge, Mar 21, 2021)
03b7638  Fix bugs in tests (NielsRogge, Mar 22, 2021)
f6556b5  One more bug fix (NielsRogge, Mar 22, 2021)
884b7a7  Revert previous change (NielsRogge, Mar 22, 2021)
f35360e  Address most comments by @sgugger @LysandreJik (NielsRogge, Mar 22, 2021)
f9a1ac6  Update conversion script (NielsRogge, Mar 23, 2021)
472f96d  Rename self.self to self.attention (NielsRogge, Mar 23, 2021)
e790c1d  Add pooler option to ViTForImageClassification, improve docs (NielsRogge, Mar 24, 2021)
8b95a1e  Add ViTFeatureExtractor to conversion script (NielsRogge, Mar 24, 2021)
37ae119  Add copyright (NielsRogge, Mar 24, 2021)
c6c0f27  Address additional comments (NielsRogge, Mar 26, 2021)
16 changes: 8 additions & 8 deletions .circleci/config.yml
@@ -79,8 +79,8 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
@@ -107,8 +107,8 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
paths:
@@ -135,8 +135,8 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- run: pip install .[sklearn,torch,testing,sentencepiece,speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
@@ -215,8 +215,8 @@ jobs:
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- run: pip install .[sklearn,torch,testing,sentencepiece,speech,vision]
- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
paths:
1 change: 1 addition & 0 deletions README.md
@@ -238,6 +238,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[XLM-RoBERTa](https://huggingface.co/transformers/model_doc/xlmroberta.html)** (from Facebook AI), released together with the paper [Unsupervised Cross-lingual Representation Learning at Scale](https://arxiv.org/abs/1911.02116) by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov.
1. **[XLNet](https://huggingface.co/transformers/model_doc/xlnet.html)** (from Google/CMU) released with the paper [​XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
1. **[XLSR-Wav2Vec2](https://huggingface.co/transformers/model_doc/xlsr_wav2vec2.html)** (from Facebook AI) released with the paper [Unsupervised Cross-Lingual Representation Learning For Speech Recognition](https://arxiv.org/abs/2006.13979) by Alexis Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR.

To check if each model has an implementation in PyTorch/TensorFlow/Flax or has an associated tokenizer backed by the 🤗 Tokenizers library, refer to [this table](https://huggingface.co/transformers/index.html#bigtable)
7 changes: 7 additions & 0 deletions docs/source/index.rst
@@ -223,6 +223,10 @@ and conversion utilities for the following models:
47. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
48. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.


.. _bigtable:
@@ -319,6 +323,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ViT | ❌ | ❌ | ✅ | ❌ | ❌ |
Review comment (Member): Seeing this, it looks like the ViT support is quite incomplete, even though it's not the case. I think we should eventually rethink how this is designed so that feature processors are highlighted here. Maybe by modifying "Tokenizer slow" to be "Pre-processor" and "Tokenizer fast" to be "Performance-optimized pre-processor". Let's think about it cc @sgugger

Reply (Collaborator): This is for a further PR though ;-) But yes, definitely worth a look!

+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
@@ -448,6 +454,7 @@ TensorFlow and/or Flax.
model_doc/t5
model_doc/tapas
model_doc/transformerxl
model_doc/vit
model_doc/wav2vec2
model_doc/xlm
model_doc/xlmprophetnet
94 changes: 94 additions & 0 deletions docs/source/model_doc/vit.rst
@@ -0,0 +1,94 @@
..
Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

Vision Transformer (ViT)
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Vision Transformer (ViT) model was proposed in `An Image is Worth 16x16 Words: Transformers for Image Recognition
at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk
Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob
Uszkoreit, Neil Houlsby. It's the first paper that successfully trains a Transformer encoder on ImageNet, attaining
very good results compared to familiar convolutional architectures.


The abstract from the paper is the following:

*While the Transformer architecture has become the de-facto standard for natural language processing tasks, its
applications to computer vision remain limited. In vision, attention is either applied in conjunction with
convolutional networks, or used to replace certain components of convolutional networks while keeping their overall
structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to
sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of
data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.),
Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring
substantially fewer computational resources to train.*

Tips:

- To feed images to the Transformer encoder, each image is split into a sequence of fixed-size patches, which are then
linearly embedded. A [CLS] token is added to serve as a representation of the entire image, which can be used for
classification. The authors also add absolute position embeddings, and feed the resulting sequence of vectors to a
standard Transformer encoder (see the first sketch after these tips).
- The Vision Transformer was pre-trained using a resolution of 224x224. During fine-tuning, it is often beneficial to
use a higher resolution than pre-training `(Touvron et al., 2019) <https://arxiv.org/abs/1906.06423>`__, `(Kolesnikov
et al., 2020) <https://arxiv.org/abs/1912.11370>`__. The authors report the best results with a resolution of 384x384
during fine-tuning.
- As the Vision Transformer expects each image to be of the same size (resolution), one can use
:class:`~transformers.ViTFeatureExtractor` to resize (or rescale) and normalize images for the model (see the usage
sketch after these tips).
- Both the patch resolution and image resolution used during fine-tuning are reflected in the name of each checkpoint.
For example, :obj:`google/vit-base-patch16-224` refers to a base architecture with patch resolution of 16x16 and
fine-tuning resolution of 224x224. All checkpoints can be found on the `hub
<https://huggingface.co/models?search=vit>`__.
- The available checkpoints are pre-trained on `ImageNet-21k <http://www.image-net.org/>`__ (a collection of 14 million
images and 21k classes), and then fine-tuned on `ImageNet <http://www.image-net.org/challenges/LSVRC/2012/>`__ (also
referred to as ILSVRC 2012, a collection of 1.3 million images and 1,000 classes).
- The best results are obtained with supervised pre-training, which is not the case in NLP. The authors also
experimented with a self-supervised pre-training objective, namely masked patch prediction (inspired by masked
language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
improvement of 2% over training from scratch, but still 4% behind supervised pre-training.
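
The following sketch illustrates the patch embedding described in the first tip. This is illustrative PyTorch only,
not the PR's implementation: the patch projection is written as a strided convolution (a standard equivalent of
slicing the image into patches and applying a shared linear layer), and all sizes are the ViT-Base defaults.

.. code-block:: python

    import torch
    import torch.nn as nn

    # ViT-Base defaults: 224x224 images, 16x16 patches, hidden size 768
    batch, channels, image_size, patch_size, hidden_size = 1, 3, 224, 16, 768
    num_patches = (image_size // patch_size) ** 2  # 196

    # Shared linear projection of each patch, expressed as a strided conv
    projection = nn.Conv2d(channels, hidden_size, kernel_size=patch_size, stride=patch_size)
    cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
    position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))

    pixel_values = torch.randn(batch, channels, image_size, image_size)
    patches = projection(pixel_values).flatten(2).transpose(1, 2)  # (1, 196, 768)
    embeddings = torch.cat([cls_token.expand(batch, -1, -1), patches], dim=1)
    embeddings = embeddings + position_embeddings  # fed to the Transformer encoder

And a minimal sketch of the intended end-to-end usage. The exact :class:`~transformers.ViTFeatureExtractor` call
signature is still being settled in this PR, so the call and return format below are assumptions; the image path is a
placeholder.

.. code-block:: python

    from PIL import Image
    from transformers import ViTFeatureExtractor, ViTForImageClassification

    image = Image.open("path/to/image.jpg")  # placeholder path

    # Assumed API: resizes to 224x224, normalizes, and batches into pixel_values
    feature_extractor = ViTFeatureExtractor()
    pixel_values = feature_extractor(image)  # assumed shape: (1, 3, 224, 224)

    model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
    outputs = model(pixel_values=pixel_values)
    predicted_class_idx = outputs.logits.argmax(-1).item()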


The original code (written in JAX) can be found `here <https://github.com/google-research/vision_transformer>`__.

Note that the weights were converted from Ross Wightman's `timm library
<https://github.com/rwightman/pytorch-image-models>`__, which had already converted them from JAX to PyTorch. Credits
go to him!


ViTConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ViTConfig
:members:


ViTFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ViTFeatureExtractor
:members: __call__


ViTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ViTModel
:members: forward


ViTForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ViTForImageClassification
:members: forward
4 changes: 3 additions & 1 deletion setup.py
@@ -130,6 +130,7 @@
"tokenizers>=0.10.1,<0.11",
"torch>=1.0",
"torchaudio",
"torchvision",
"tqdm>=4.27",
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
@@ -225,6 +226,7 @@ def run(self):

extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["speech"] = deps_list("soundfile", "torchaudio")
extras["vision"] = deps_list("torchvision")

extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = (
@@ -235,7 +237,7 @@
extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton")
extras["quality"] = deps_list("black", "isort", "flake8")

extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"]
extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"] + extras["speech"] + extras["vision"]

extras["dev"] = (
extras["all"]
30 changes: 29 additions & 1 deletion src/transformers/__init__.py
@@ -48,6 +48,7 @@
is_tf_available,
is_tokenizers_available,
is_torch_available,
is_torchvision_available,
)
from .utils import logging

@@ -105,6 +106,7 @@
"is_tokenizers_available",
"is_torch_available",
"is_torch_tpu_available",
"is_torchvision_available",
],
"hf_argparser": ["HfArgumentParser"],
"integrations": [
@@ -209,6 +211,7 @@
"TransfoXLCorpus",
"TransfoXLTokenizer",
],
"models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
"models.wav2vec2": [
"WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Wav2Vec2Config",
@@ -295,7 +298,7 @@
name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_")
]

# tokenziers-backed objects
# tokenizers-backed objects
if is_tokenizers_available():
# Fast tokenizers
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
@@ -408,6 +411,7 @@
_import_structure["models.auto"].extend(
[
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@@ -824,6 +828,15 @@
"load_tf_weights_in_transfo_xl",
]
)
_import_structure["models.vit"].extend(
[
"VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTForImageClassification",
"ViTLayer",
"ViTModel",
"ViTPreTrainedModel",
]
)
_import_structure["models.wav2vec2"].extend(
[
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1245,6 +1258,9 @@
name for name in dir(dummy_flax_objects) if not name.startswith("_")
]

# Torchvision-backed objects
if is_torchvision_available():
_import_structure["models.vit"].append("ViTFeatureExtractor")

# Direct imports for type-checking
if TYPE_CHECKING:
@@ -1410,6 +1426,7 @@
TransfoXLCorpus,
TransfoXLTokenizer,
)
from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig, ViTFeatureExtractor
from .models.wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
Wav2Vec2Config,
@@ -1530,6 +1547,9 @@
else:
from .utils.dummy_tokenizers_objects import *

if is_torchvision_available():
from .models.vit import ViTFeatureExtractor

# Modeling
if is_torch_available():

@@ -1589,6 +1609,7 @@
)
from .models.auto import (
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -1927,6 +1948,13 @@
TransfoXLPreTrainedModel,
load_tf_weights_in_transfo_xl,
)
from .models.vit import (
VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTForImageClassification,
ViTLayer,
ViTModel,
ViTPreTrainedModel,
)
from .models.wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForCTC,
1 change: 1 addition & 0 deletions src/transformers/dependency_versions_table.py
@@ -49,6 +49,7 @@
"tokenizers": "tokenizers>=0.10.1,<0.11",
"torch": "torch>=1.0",
"torchaudio": "torchaudio",
"torchvision": "torchvision",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
17 changes: 15 additions & 2 deletions src/transformers/file_utils.py
@@ -182,14 +182,23 @@
except importlib_metadata.PackageNotFoundError:
_soundfile_available = False

_torchaudio_available = importlib.util.find_spec("torchaudio")

_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
try:
_torchaudio_version = importlib_metadata.version("torchaudio")
logger.debug(f"Successfully imported soundfile version {_torchaudio_version}")
logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
except importlib_metadata.PackageNotFoundError:
_torchaudio_available = False


_torchvision_available = importlib.util.find_spec("torchvision") is not None
try:
_torchvision_version = importlib_metadata.version("torchvision")
logger.debug(f"Successfully imported torchvision version {_torchvision_version}")
except importlib_metadata.PackageNotFoundError:
_torchvision_available = False


torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
@@ -381,6 +390,10 @@ def is_torchaudio_available():
return _torchaudio_available


def is_torchvision_available():
return _torchvision_available


def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available:
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -65,6 +65,7 @@
t5,
tapas,
transfo_xl,
vit,
wav2vec2,
xlm,
xlm_roberta,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -29,6 +29,7 @@
if is_torch_available():
_import_structure["modeling_auto"] = [
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
"MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@@ -90,6 +91,7 @@
if is_torch_available():
from .modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,