Add Vision Transformer + ViTFeatureExtractor #10513
Conversation
Force-pushed from 392546f to 7b4f1c7
Force-pushed from dfc6660 to 7d3fff0
Hey @NielsRogge
Some common modeling tests depend on the specific parameter names (…). Also, the tests for … You could use the modeling tests of …

I like the overall design. We should also add the …
Force-pushed from a830014 to de78130
Force-pushed from de78130 to e01294c
Thanks a lot for adding this model! The main problem I have is with the self.self for the self-attention. It's there in BERT and there is nothing we can do about it now, but we can still make sure to use another name in newer models!
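For illustration, a minimal sketch of the naming difference being discussed (the class and attribute names here are illustrative placeholders, not the literal code of this PR):

```python
import torch.nn as nn


class SelfAttention(nn.Module):
    """Stand-in for an actual self-attention implementation."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)


class BertStyleAttention(nn.Module):
    # BERT-era naming: the sub-module is stored as `self.self`,
    # so downstream code reads `self.self.query`, which is confusing.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.self = SelfAttention(hidden_size)


class NewStyleAttention(nn.Module):
    # Preferred naming for newer models: a descriptive attribute name.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.attention = SelfAttention(hidden_size)
```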
"norm.weight", | ||
"norm.bias", | ||
"head.weight", | ||
"head.bias", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can all fit in one line.
I know, it's make style that does this.
Just tested locally and it did not change the line
ignore_keys = ["norm.weight", "norm.bias", "head.weight", "head.bias"]
This looks really good, fantastic job @NielsRogge!
Related to the model card:
- The model doesn't have a model card as of now; it would be amazing to have one
- The model configuration on your nielsr/vit-base-patch16-224 repo has all the labels as "LABELS_{i}"; it would be great to have the actual label names!

Other than that, it looks in very good shape! I'm wondering about the feature processor: as I understand it, it's not framework-agnostic. Also, the AutoModel is very low-hanging fruit since we already have the mapping.
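As a hedged sketch of how the label names could be fixed (ViTConfig is the configuration class added in this PR; the list of ImageNet class names is a placeholder that would need to be filled in):

```python
from transformers import ViTConfig  # assumes the config class introduced in this PR

# Placeholder: the full list of 1000 human-readable ImageNet-1k class names,
# e.g. loaded from a JSON file. Only the idea is sketched here.
imagenet_class_names = ["tench", "goldfish", "great white shark"]  # ... and 997 more

config = ViTConfig.from_pretrained("nielsr/vit-base-patch16-224")
config.id2label = {i: name for i, name in enumerate(imagenet_class_names)}
config.label2id = {name: i for i, name in config.id2label.items()}

# Saving and re-uploading config.json would make the hub show real label names
# instead of the auto-generated "LABELS_{i}" placeholders.
config.save_pretrained("./vit-base-patch16-224")
```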
@@ -319,6 +323,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ViT | ❌ | ❌ | ✅ | ❌ | ❌ |
Seeing this, it looks like the ViT support is quite incomplete, even though that's not the case. I think we should eventually rethink how this is designed so that feature processors are highlighted here. Maybe by modifying "Tokenizer slow" to be "Pre-processor" and "Tokenizer fast" to be "Performance-optimized pre-processor". Let's think about it cc @sgugger
This is for a further PR though ;-) But yes, definitely worth a look!
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
A comment here would be helpful
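For instance, just as a sketch of the kind of comment that could be added to the snippet above:

```python
super().__init__()
# image_size and patch_size may be passed as a single int or a (height, width) tuple;
# to_2tuple normalizes both to a tuple.
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
# Number of non-overlapping patches the image is split into:
# (image width // patch width) * (image height // patch height).
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
```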
Thanks for the reviews, I've addressed most of the comments. To do: …
There are multiple instances of weird styling. In general we use the 119-char line to its maximum (you can add a ruler in your IDE to see where it is). Sadly, make style does not put code that was split into several lines back on one line if you are using code copied from another part of the lib as a base (where the split might be justified because there were more objects or longer names in the original), so it has to be done by hand.
embedding_output = self.embeddings(
    pixel_values,
)

Suggested change:

embedding_output = self.embeddings(pixel_values)
""" | ||
Decorator marking a test that requires torchvision. | ||
|
||
These tests are skipped when torchvision isn't installed. | ||
|
||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
""" | |
Decorator marking a test that requires torchvision. | |
These tests are skipped when torchvision isn't installed. | |
""" | |
""" | |
Decorator marking a test that requires torchvision. These tests are skipped when torchvision isn't installed. | |
""" |
(
    config,
    pixel_values,
    labels,
) = config_and_inputs

Suggested change:

config, pixel_values, labels = config_and_inputs
Cool, this looks great! Looking forward to seeing @sgugger's take on the feature processor.
Played around with the model a bit, it's fun! Great job on the implementation @NielsRogge!
I've addressed all comments. Most important updates: …

The remaining comments which are still open have to do with styling. I seem to have some issues with …
What does this PR do?
This PR includes 2 things:
it adds the Vision Transformer (ViT) by Google Brain. ViT is a Transformer encoder trained on ImageNet. It is capable of classifying images by placing a linear classification head on top of the final hidden state of the [CLS] token. I converted the weights from the timm repository, which already took care of converting the weights of the original implementation (which is written in JAX) into PyTorch. Once this model is added, we can also add DeiT (Data-efficient Image Transformers) by Facebook AI, which improves upon ViT.
it provides a design for the ViTFeatureExtractor class, which can be used to prepare images for the model. It inherits from FeatureExtractionMixin and defines a __call__ method. It currently accepts 3 types of inputs: PIL images, NumPy arrays and PyTorch tensors. It defines 2 transformations using torchvision: resizing + normalization. It then returns a BatchFeature object with 1 key, namely pixel_values.

Demo notebook of the combination of ViTForImageClassification + ViTFeatureExtractor: https://colab.research.google.com/drive/16TCM-tJ1Mfhs00Qas063kWZmAtVJcOeP?usp=sharing
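A minimal sketch of the intended usage (mirroring the demo notebook; the exact call signature, argument names and checkpoint contents are assumptions based on this PR's design and may differ from the final API):

```python
import requests
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Example image of two cats from the COCO validation set.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("nielsr/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("nielsr/vit-base-patch16-224")

# Resize + normalize; returns a BatchFeature with a single key, "pixel_values".
inputs = feature_extractor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(pixel_values=inputs["pixel_values"]).logits

predicted_class = logits.argmax(-1).item()
print(model.config.id2label[predicted_class])
```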
Compared to NLP models (which accept input_ids, attention_mask and token_type_ids), this model only accepts pixel_values. The model itself then converts these pixel values into patches (in the case of ViT) in the ViTEmbeddings class.
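For intuition, a sketch of how pixel values are typically turned into patch embeddings in ViT-style models: a convolution whose kernel size and stride both equal the patch size, as in the timm implementation (shown here as a standalone sketch with the base-model defaults, not the literal code of this PR):

```python
import torch
import torch.nn as nn

batch_size, num_channels, image_size, patch_size, hidden_size = 2, 3, 224, 16, 768

# One strided convolution projects each 16x16 patch to a hidden_size-dim vector.
projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

pixel_values = torch.randn(batch_size, num_channels, image_size, image_size)
patches = projection(pixel_values)                      # (2, 768, 14, 14)
patch_embeddings = patches.flatten(2).transpose(1, 2)   # (2, 196, 768): one row per patch
```

A [CLS] token embedding and position embeddings are then added before the sequence goes through the Transformer encoder.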
Help needed
Would be great if you can help me with the following tasks:
- test_modeling_vit.py and test_feature_extraction_vit.py. However, for the former, since ViT does not use input_ids/input_embeds, some tests are failing, so I wonder whether it should use all tests defined in test_modeling_common.py. For the latter, I also need some help in creating random inputs to test the feature extractor on (one possible sketch is shown after this list).
- head_mask in the forward of ViTModel. Possibly remove attention_mask?
- make fix-copies (doesn't work right now for me on Windows)
- Remove the is_decoder logic from modeling_vit.py (since the model was created using the CookieCutter template). I assume that things such as past_key_values are not required for an encoder-only model.
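One possible sketch for the random image inputs mentioned in the first item (the helper name and default sizes are made up for illustration):

```python
import numpy as np
import torch
from PIL import Image


def prepare_random_image_inputs(batch_size=4, num_channels=3, height=224, width=224, kind="pil"):
    """Create a batch of random images as PIL images, NumPy arrays or PyTorch tensors."""
    arrays = [
        np.random.randint(0, 256, (height, width, num_channels), dtype=np.uint8)
        for _ in range(batch_size)
    ]
    if kind == "pil":
        return [Image.fromarray(arr) for arr in arrays]
    if kind == "numpy":
        return arrays
    # PyTorch tensors in channels-first format, as the model expects.
    return [torch.from_numpy(arr).permute(2, 0, 1) for arr in arrays]
```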
Who can review?

@patrickvonplaten @LysandreJik @sgugger