
Add Vision Transformer + ViTFeatureExtractor #10513

Closed

Conversation

@NielsRogge (Contributor) commented Mar 4, 2021

What does this PR do?

This PR includes 2 things:

  • it adds the Vision Transformer (ViT) by Google Brain. ViT is a Transformer encoder trained on ImageNet. It is capable of classifying images by placing a linear classification head on top of the final hidden state of the [CLS] token. I converted the weights from the timm repository, which already took care of converting the weights of the original implementation (which is written in JAX) into PyTorch. Once this model is added, we can also add DeiT (Data-efficient Image Transformers) by Facebook AI, which improves upon ViT.

  • it provides a design for the ViTFeatureExtractor class, which can be used to prepare images for the model. It inherits from FeatureExtractionMixin and defines a __call__ method. It currently accepts 3 types of inputs: PIL images, NumPy arrays and PyTorch tensors. It defines 2 transformations using torchvision: resizing + normalization. It then returns a BatchFeature object with 1 key, namely pixel_values. A sketch of this design follows below.
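
For concreteness, here is a minimal sketch of the design described in the second bullet (the class name, parameter names and normalization values are illustrative assumptions, not the PR's final API):

    from PIL import Image
    import numpy as np
    import torch
    from torchvision import transforms


    class ViTFeatureExtractorSketch:
        def __init__(self, size=224, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5)):
            # The two transformations mentioned above: resizing + normalization.
            self.transforms = transforms.Compose(
                [
                    transforms.Resize((size, size)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=image_mean, std=image_std),
                ]
            )

        def __call__(self, images):
            # Accept a single image or a list of PIL images / NumPy arrays / PyTorch tensors.
            if not isinstance(images, (list, tuple)):
                images = [images]
            pil_images = []
            for image in images:
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                elif isinstance(image, torch.Tensor):
                    image = transforms.ToPILImage()(image)
                pil_images.append(image)
            pixel_values = torch.stack([self.transforms(image) for image in pil_images])
            # The real class returns a BatchFeature; a plain dict with the same key stands in here.
            return {"pixel_values": pixel_values}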

Demo notebook combining ViTForImageClassification + ViTFeatureExtractor: https://colab.research.google.com/drive/16TCM-tJ1Mfhs00Qas063kWZmAtVJcOeP?usp=sharing

Compared to NLP models (which accept input_ids, attention_mask and token_type_ids), this model only accepts pixel_values. The model itself then converts these pixel values into patches (in the case of ViT) in the ViTEmbeddings class.
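
As an illustration of that conversion, patch extraction is commonly implemented as a convolution whose kernel size and stride both equal the patch size; a minimal sketch (module and attribute names are assumptions, not necessarily those of ViTEmbeddings):

    import torch
    from torch import nn


    class PatchEmbeddingsSketch(nn.Module):
        """Turns pixel_values of shape (batch, 3, 224, 224) into a sequence of patch embeddings."""

        def __init__(self, image_size=224, patch_size=16, num_channels=3, hidden_size=768):
            super().__init__()
            self.num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 for the defaults
            # A convolution with kernel_size == stride == patch_size embeds each patch independently.
            self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

        def forward(self, pixel_values):
            # (batch, hidden_size, 14, 14) -> (batch, 196, hidden_size)
            return self.projection(pixel_values).flatten(2).transpose(1, 2)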

Help needed

It would be great if you could help me with the following tasks:

  • Add and improve tests. Currently I have defined the following tests: test_modeling_vit.py, test_feature_extraction_vit.py. However, for the former, since ViT does not use input_ids/inputs_embeds, some tests are failing, so I wonder whether it should use all tests defined in test_modeling_common.py. For the latter, I also need some help in creating random inputs to test the feature extractor on.
  • Add support for head_mask in the forward of ViTModel. Possibly remove attention_mask?
  • Run make fix-copies (doesn't work right now for me on Windows)
  • Remove the is_decoder logic from modeling_vit.py (since the model was created using the CookieCutter template). I assume that things such as past_key_values are not required for an encoder-only model.

Who can review?

@patrickvonplaten @LysandreJik @sgugger

@NielsRogge changed the title from "Add Vision Transformer + PreTrainedImageProcessor" to "[WIP] Add Vision Transformer + PreTrainedImageProcessor" on Mar 4, 2021
@NielsRogge force-pushed the modeling_vit_pytorch_v2 branch from 392546f to 7b4f1c7 on March 4, 2021
@NielsRogge force-pushed the modeling_vit_pytorch_v2 branch 2 times, most recently from dfc6660 to 7d3fff0 on March 16, 2021
@NielsRogge changed the title from "[WIP] Add Vision Transformer + PreTrainedImageProcessor" to "[WIP] Add Vision Transformer + ViTFeatureExtractor" on Mar 16, 2021
@patil-suraj (Contributor) commented Mar 16, 2021

Hey @NielsRogge

Add and improve tests. Currently I have defined the following tests: test_modeling_vit.py, test_feature_extraction_vit.py. However, for the former, since ViT does not use input_ids/inputs_embeds, some tests are failing, so I wonder whether it should use all tests defined in test_modeling_common.py. For the latter, I also need some help in creating random inputs to test the feature extractor on.

Some common modeling tests depend on specific parameter names (input_ids, inputs_embeds). You could just override such tests in your test class and use the correct parameter names. For example, the test_forward_signature test expects input_ids, so it should be overridden in your class to expect pixel_values.

Also, the tests for inputs_embeds (for example test_inputs_embeds) can be skipped, since ViT does not use those. Again, just override the test and use pass in the method body.

You could use the modeling tests of Wav2Vec2 and Speech2Text for reference since those models also use different parameter names.
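
A sketch of what those overrides could look like (the real test class mixes in the library's ModelTesterMixin; this standalone version assumes ViTModel is importable once the PR is merged):

    import inspect
    import unittest

    from transformers import ViTModel


    class ViTModelTest(unittest.TestCase):
        def test_forward_signature(self):
            # Overridden: ViT takes pixel_values instead of input_ids.
            signature = inspect.signature(ViTModel.forward)
            arg_names = list(signature.parameters.keys())
            self.assertEqual(arg_names[1], "pixel_values")  # index 0 is `self`

        def test_inputs_embeds(self):
            # Skipped: ViT has no inputs_embeds.
            pass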

@patil-suraj (Contributor) commented Mar 16, 2021

I like the overall design of ViTFeatureExtractor. Regarding the import of ViTFeatureExtractor: I think it should always be imported in the init files, and instead ViTFeatureExtractor could check for torchvision and raise if it's not installed. Otherwise, the TF tests on CI will fail, because they won't be able to import ViTFeatureExtractor as we don't install torchvision in TF tests.

We should also add the torchvision and PIL dependencies to setup.py as extras["vision"], and also add them to config.yml for CI.
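
A minimal sketch of such a guard (the module-level flag and error message below are illustrative; transformers has its own availability helpers that the final code would likely use):

    import importlib.util

    # Check for torchvision without importing it, so this module can always be imported.
    _torchvision_available = importlib.util.find_spec("torchvision") is not None


    class ViTFeatureExtractor:
        def __init__(self, *args, **kwargs):
            if not _torchvision_available:
                raise ImportError(
                    "ViTFeatureExtractor requires torchvision. Install it with `pip install torchvision`."
                )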

@NielsRogge force-pushed the modeling_vit_pytorch_v2 branch 5 times, most recently from a830014 to de78130 on March 19, 2021
@NielsRogge force-pushed the modeling_vit_pytorch_v2 branch from de78130 to e01294c on March 19, 2021
@NielsRogge changed the title from "[WIP] Add Vision Transformer + ViTFeatureExtractor" to "Add Vision Transformer + ViTFeatureExtractor" on Mar 22, 2021
@sgugger (Collaborator) left a comment

Thanks a lot for adding this model! The main problem I have is with the self.self for the self-attention. It's there in BERT and there is nothing we can do about it now, but we can still make sure to use another name in newer models!
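
For illustration, the attention wrapper can expose the inner module under a clearer name; a minimal sketch (the inner class is stubbed out, and the names are assumptions based on the PR's ViT prefix):

    from torch import nn


    class ViTSelfAttention(nn.Module):
        """Stub standing in for the actual self-attention implementation."""

        def __init__(self, config):
            super().__init__()


    class ViTAttention(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Named `attention` rather than BERT's historical `self.self`.
            self.attention = ViTSelfAttention(config)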

src/transformers/__init__.py (outdated, resolved)
src/transformers/models/vit/configuration_vit.py (outdated, resolved)
Comment on lines +131 to +134
"norm.weight",
"norm.bias",
"head.weight",
"head.bias",
Collaborator:

This can all fit in one line.

Contributor Author:

I know, it's make style that does this.

Collaborator:

Just tested locally and it did not change the line:

    ignore_keys = ["norm.weight", "norm.bias", "head.weight", "head.bias"]

src/transformers/models/vit/feature_extraction_vit.py (outdated, resolved)
src/transformers/models/vit/feature_extraction_vit.py (outdated, resolved)
src/transformers/models/vit/modeling_vit.py (outdated, resolved)
src/transformers/models/vit/modeling_vit.py (outdated, resolved)
src/transformers/models/vit/modeling_vit.py (outdated, resolved)
src/transformers/models/vit/modeling_vit.py (outdated, resolved)
src/transformers/models/vit/modeling_vit.py (outdated, resolved)
@LysandreJik (Member) left a comment

This looks really good, fantastic job @NielsRogge!

Related to the model card:

  • The model doesn't have a model card as of now; it would be amazing to have one.
  • The model configuration on your nielsr/vit-base-patch16-224 repo has all the labels as "LABELS_{i}"; it would be great to have the actual label names! (See the sketch below.)
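
As a hedged illustration of the second point, label names can be attached through the standard id2label/label2id configuration attributes (only the first three ImageNet-1k classes are shown; the full mapping has 1,000 entries):

    from transformers import ViTConfig

    # First entries of the ImageNet-1k index-to-label mapping; the rest are elided.
    id2label = {0: "tench", 1: "goldfish", 2: "great white shark"}

    config = ViTConfig(
        id2label=id2label,
        label2id={label: idx for idx, label in id2label.items()},
    )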

Other than that, it looks in very good shape! I'm wondering about the feature processor; as I understand it, it's not framework-agnostic. Also, the AutoModel is very low-hanging fruit when we already have the mapping.

.circleci/config.yml (resolved)
src/transformers/models/vit/modeling_vit.py (outdated, resolved)
@@ -319,6 +323,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ViT | ❌ | ❌ | ✅ | ❌ | ❌ |
Member:

Seeing this, it looks like the ViT support is quite incomplete, even though that's not the case. I think we should eventually rethink how this is designed so that feature processors are highlighted here, maybe by modifying "Tokenizer slow" to be "Pre-processor" and "Tokenizer fast" to be "Performance-optimized pre-processor". Let's think about it cc @sgugger

Collaborator:

This is for a further PR though ;-) But yes, definitely worth a look!

setup.py (resolved)
src/transformers/__init__.py (outdated, resolved)
src/transformers/models/auto/modeling_auto.py (resolved)
src/transformers/models/vit/configuration_vit.py (outdated, resolved)
src/transformers/models/vit/configuration_vit.py (outdated, resolved)
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
Member:

A comment here would be helpful

@NielsRogge (Contributor Author) commented Mar 22, 2021

Thanks for the reviews; I've addressed most of the comments. To do:

  • rename self.self to self.attention and update conversion script accordingly
  • convert more models, place them under the google namespace
  • add model cards
  • add the 1,000 ImageNet class names to the config

docs/source/model_doc/vit.rst (resolved)
docs/source/model_doc/vit.rst (outdated, resolved)
@sgugger (Collaborator) left a comment

There are multiple instances of weird styling. In general, we use the 119-character line length to its maximum (you can add a ruler in your IDE to see where it is). Sadly, make style does not put code that was split across several lines back on one line if you are using code copied from another part of the library as a base (where the split might have been justified because there were more objects or longer names in the original), so it has to be done by hand.

docs/source/model_doc/vit.rst (outdated, resolved)
src/transformers/__init__.py (outdated, resolved)
src/transformers/models/vit/convert_vit_timm_to_pytorch.py (outdated, resolved)
src/transformers/models/vit/convert_vit_timm_to_pytorch.py (outdated, resolved)
Comment on lines +521 to +523
embedding_output = self.embeddings(
pixel_values,
)
Collaborator:

Suggested change
embedding_output = self.embeddings(
pixel_values,
)
embedding_output = self.embeddings(pixel_values)

Comment on lines +245 to +250
"""
Decorator marking a test that requires torchvision.

These tests are skipped when torchvision isn't installed.

"""
Collaborator:

Suggested change
"""
Decorator marking a test that requires torchvision.
These tests are skipped when torchvision isn't installed.
"""
"""
Decorator marking a test that requires torchvision. These tests are skipped when torchvision isn't installed.
"""

Comment on lines +127 to +131
(
config,
pixel_values,
labels,
) = config_and_inputs
Collaborator:

Suggested change
(
config,
pixel_values,
labels,
) = config_and_inputs
config, pixel_values, labels = config_and_inputs

tests/test_modeling_vit.py (outdated, resolved)
tests/test_modeling_vit.py (outdated, resolved)
@LysandreJik (Member) left a comment

Cool, this looks great! Looking forward to seeing @sgugger's take on the feature processor.

Played around with the model a bit, it's fun! Great job on the implementation @NielsRogge!

src/transformers/models/auto/__init__.py (resolved)
src/transformers/models/auto/modeling_auto.py (resolved)
src/transformers/models/vit/__init__.py (outdated, resolved)
@NielsRogge (Contributor Author) commented:

I've addressed all comments. Most important updates:

  • moved the ImageNet id-to-class dict to a new file under transformers.utils named imagenet_classes.py.
  • added a warning to the __call__ method of ViTFeatureExtractor to indicate that NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it's most efficient to pass in PIL images. (A sketch of the warning follows below.)
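
A minimal sketch of what that warning could look like (logger setup, helper name, and wording are assumptions):

    import logging

    from PIL import Image

    logger = logging.getLogger(__name__)


    def _warn_if_not_pil(image):
        # Resizing is done on PIL images, so other input types incur a conversion.
        if not isinstance(image, Image.Image):
            logger.warning(
                "NumPy arrays and PyTorch tensors are converted to PIL images when resizing, "
                "so it is most efficient to pass PIL images."
            )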

The remaining open comments have to do with styling. I seem to have some issues with make style: the maximum line length is set to 119, so I'm not sure what's causing this.
