* begin second draft
* fix import, style
* add loss
* fix embeds, logits_scale, and projection
* fix imports
* add conversion script
* add feature_extractor and processor
* style
* add tests for tokenizer, extractor and processor
* add vision model tests
* add weight init
* add more tests
* fix save_load test
* model output, dosstrings, causal mask
* config doc
* add clip model tests
* return dict
* bigin integration test
* add integration tests
* fix-copies
* fix init
* Clip => CLIP
* fix module name
* docs
* fix doc
* output_dim => projection_dim
* fix checkpoint names
* remoe fast tokenizer file
* fix conversion script
* fix tests, quality
* put causal mask on device
* Apply suggestions from code review (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* fix attribute test
* style
* address sylvains comments
* style
* fix docstrings
* add qucik_gelu in activations, docstrings
* clean-up attention test
* fix act fun
* fix config
* fix torchscript tests
* even batch_size
* remove comment
* fix ouput tu_tuple
* fix save load tests
* fix add tokens test
* add fast tokenizer
* update copyright
* new processor API
* fix docs
* docstrings
* docs
* fix doc
* fix doc
* fix tokenizer
* fix import in doc example
* Apply suggestions from code review (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* check types of config
* valhalla => openai
* load image using url
* fix test
* typo

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent 4ce6bcc · commit 8719afa · Showing 25 changed files with 3,848 additions and 45 deletions.
@@ -0,0 +1,154 @@
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

CLIP
------------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The CLIP model was proposed in `Learning Transferable Visual Models From Natural Language Supervision
<https://arxiv.org/abs/2103.00020>`__ by Alec Radford, Jong Wook Kim, Chris Hallacy, Aditya Ramesh, Gabriel Goh,
Sandhini Agarwal, Girish Sastry, Amanda Askell, Pamela Mishkin, Jack Clark, Gretchen Krueger, Ilya Sutskever. CLIP
(Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be
instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing
for the task, similarly to the zero-shot capabilities of GPT-2 and 3.

The abstract from the paper is the following:

*State-of-the-art computer vision systems are trained to predict a fixed set of predetermined object categories. This
restricted form of supervision limits their generality and usability since additional labeled data is needed to specify
any other visual concept. Learning directly from raw text about images is a promising alternative which leverages a
much broader source of supervision. We demonstrate that the simple pre-training task of predicting which caption goes
with which image is an efficient and scalable way to learn SOTA image representations from scratch on a dataset of 400
million (image, text) pairs collected from the internet. After pre-training, natural language is used to reference
learned visual concepts (or describe new ones) enabling zero-shot transfer of the model to downstream tasks. We study
the performance of this approach by benchmarking on over 30 different existing computer vision datasets, spanning tasks
such as OCR, action recognition in videos, geo-localization, and many types of fine-grained object classification. The
model transfers non-trivially to most tasks and is often competitive with a fully supervised baseline without the need
for any dataset specific training. For instance, we match the accuracy of the original ResNet-50 on ImageNet zero-shot
without needing to use any of the 1.28 million training examples it was trained on. We release our code and pre-trained
model weights at this https URL.*

Usage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

CLIP is a multi-modal vision and language model. It can be used for image-text similarity and for zero-shot image
classification. CLIP uses a ViT-like Transformer to get visual features and a causal language model to get the text
features. Both the text and visual features are then projected to a latent space of identical dimension. The dot
product between the projected image and text features is then used as a similarity score.

To feed images to the Transformer encoder, each image is split into a sequence of fixed-size non-overlapping patches,
which are then linearly embedded. A [CLS] token is added to serve as the representation of the entire image. The authors
also add absolute position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder.
The :class:`~transformers.CLIPFeatureExtractor` can be used to resize (or rescale) and normalize images for the model.
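
The following is a minimal sketch of preprocessing a single image with :class:`~transformers.CLIPFeatureExtractor` on
its own (it assumes the feature extractor configuration is hosted alongside the ``openai/clip-vit-base-patch32``
checkpoint):

.. code-block::

    >>> import requests
    >>> from PIL import Image
    >>> from transformers import CLIPFeatureExtractor

    >>> feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> # resize (or rescale) and normalize the image; the result has shape (batch_size, num_channels, height, width)
    >>> pixel_values = feature_extractor(images=image, return_tensors="pt")["pixel_values"]

The resize size and the normalization mean/std come from the feature extractor configuration saved with the checkpoint.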

The :class:`~transformers.CLIPTokenizer` is used to encode the text. The :class:`~transformers.CLIPProcessor` wraps
:class:`~transformers.CLIPFeatureExtractor` and :class:`~transformers.CLIPTokenizer` into a single instance to both
encode the text and prepare the images. The following example shows how to get the image-text similarity scores using
:class:`~transformers.CLIPProcessor` and :class:`~transformers.CLIPModel`.

.. code-block::

    >>> import torch
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import CLIPProcessor, CLIPModel

    >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)

    >>> outputs = model(**inputs)
    >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
    >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
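
The projected image and text features can also be computed separately with
:meth:`~transformers.CLIPModel.get_image_features` and :meth:`~transformers.CLIPModel.get_text_features`. The sketch
below recomputes the image-text similarity by hand as the dot product of the L2-normalized projections; note that the
model additionally scales this by a learned temperature (``logit_scale``) when producing ``logits_per_image``:

.. code-block::

    >>> with torch.no_grad():
    ...     image_embeds = model.get_image_features(pixel_values=inputs["pixel_values"])
    ...     text_embeds = model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

    >>> # L2-normalize so the dot product is a cosine similarity
    >>> image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    >>> text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

    >>> similarity = image_embeds @ text_embeds.T  # shape (num_images, num_texts)

Up to the learned ``logit_scale`` factor, this matches the ``logits_per_image`` returned by the model's forward pass.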

This model was contributed by `valhalla <https://huggingface.co/valhalla>`__. The original code can be found `here
<https://github.com/openai/CLIP>`__.


CLIPConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPConfig
    :members: from_text_vision_configs


CLIPTextConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPTextConfig
    :members:


CLIPVisionConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPVisionConfig
    :members:


CLIPTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPTokenizer
    :members: build_inputs_with_special_tokens, get_special_tokens_mask,
        create_token_type_ids_from_sequences, save_vocabulary


CLIPTokenizerFast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPTokenizerFast
    :members:


CLIPFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPFeatureExtractor
    :members:


CLIPProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPProcessor
    :members:


CLIPModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPModel
    :members: forward, get_text_features, get_image_features


CLIPTextModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPTextModel
    :members: forward


CLIPVisionModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.CLIPVisionModel
    :members: forward
@@ -30,6 +30,7 @@
    blenderbot,
    blenderbot_small,
    camembert,
    clip,
    convbert,
    cpm,
    ctrl,