Adds CLIP to models exportable with ONNX (huggingface#18515)
* onnx config for clip

* default opset as 14

* changes from the original repo

* input values order fix

* outputs fix

* remove unused import

* ran make fix-copies

* black format

* review comments: forward ref, import fix, model change revert, .to cleanup

* make style

* formatting fixes

* revert groupvit

* comment for cast to int32

* comment fix

* make .T as .t() for onnx conversion

* ran make fix-copies

* remove unneeded comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix copies

* remove comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
2 people authored and oneraghavan committed Sep 26, 2022
1 parent 84101f7 commit a4ff5d0
Showing 9 changed files with 82 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
@@ -55,6 +55,7 @@ Ready-made configurations include the following architectures:
- BlenderbotSmall
- BLOOM
- CamemBERT
+- CLIP
- CodeGen
- ConvBERT
- ConvNeXT
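With CLIP listed among the ready-made configurations, the export can be driven by the documented CLI (python -m transformers.onnx --model=openai/clip-vit-base-patch32 onnx/) or programmatically. Below is a minimal sketch of the programmatic route, not part of this commit; the checkpoint name and output path are illustrative assumptions.

# Hedged sketch: exporting CLIP with the new ready-made ONNX configuration.
from pathlib import Path

from transformers import CLIPModel, CLIPProcessor
from transformers.models.clip import CLIPOnnxConfig
from transformers.onnx import export

model_id = "openai/clip-vit-base-patch32"  # illustrative checkpoint
model = CLIPModel.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)

onnx_config = CLIPOnnxConfig(model.config)

# export() builds dummy inputs through the processor and writes the graph with
# the opset declared by the config (14 for CLIP, see configuration_clip.py below).
onnx_inputs, onnx_outputs = export(
    preprocessor=processor,
    model=model,
    config=onnx_config,
    opset=onnx_config.default_onnx_opset,
    output=Path("clip.onnx"),
)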
16 changes: 14 additions & 2 deletions src/transformers/models/clip/__init__.py
@@ -29,7 +29,13 @@


_import_structure = {
-    "configuration_clip": ["CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP", "CLIPConfig", "CLIPTextConfig", "CLIPVisionConfig"],
+    "configuration_clip": [
+        "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "CLIPConfig",
+        "CLIPOnnxConfig",
+        "CLIPTextConfig",
+        "CLIPVisionConfig",
+    ],
"tokenization_clip": ["CLIPTokenizer"],
}

@@ -95,7 +101,13 @@


if TYPE_CHECKING:
-    from .configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+    from .configuration_clip import (
+        CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        CLIPConfig,
+        CLIPOnnxConfig,
+        CLIPTextConfig,
+        CLIPVisionConfig,
+    )
from .tokenization_clip import CLIPTokenizer

try:
50 changes: 49 additions & 1 deletion src/transformers/models/clip/configuration_clip.py
@@ -16,9 +16,16 @@

import copy
import os
-from typing import Union
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union


+if TYPE_CHECKING:
+    from ...processing_utils import ProcessorMixin
+    from ...utils import TensorType

from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
from ...utils import logging


@@ -317,3 +324,44 @@ def to_dict(self):
        output["vision_config"] = self.vision_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output


+class CLIPOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("pixel_values", {0: "batch"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )

+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("logits_per_image", {0: "batch"}),
+                ("logits_per_text", {0: "batch"}),
+                ("text_embeds", {0: "batch"}),
+                ("image_embeds", {0: "batch"}),
+            ]
+        )

+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4

+    def generate_dummy_inputs(
+        self,
+        processor: "ProcessorMixin",
+        framework: Optional["TensorType"] = None,
+    ) -> Mapping[str, Any]:

+        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
+        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        return {**text_input_dict, **image_input_dict}

+    @property
+    def default_onnx_opset(self) -> int:
+        return 14
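The inputs and outputs declared above become the tensor names and dynamic axes of the exported graph, and atol_for_validation (1e-4) is the tolerance used when the exported model is validated against PyTorch. A hedged sketch of running an exported file with ONNX Runtime follows; the file path matches the earlier sketch, and the onnxruntime and PIL usage are assumptions, not part of this commit.

import onnxruntime as ort
from PIL import Image

from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
session = ort.InferenceSession("clip.onnx", providers=["CPUExecutionProvider"])

image = Image.new("RGB", (224, 224))  # stand-in image
inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=image,
    padding=True,
    return_tensors="np",
)

# The requested output names follow the outputs mapping declared above.
logits_per_image, logits_per_text, text_embeds, image_embeds = session.run(
    ["logits_per_image", "logits_per_text", "text_embeds", "image_embeds"],
    dict(inputs),
)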
9 changes: 6 additions & 3 deletions src/transformers/models/clip/modeling_clip.py
@@ -68,7 +68,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
-    image_loss = contrastive_loss(similarity.T)
+    image_loss = contrastive_loss(similarity.t())
return (caption_loss + image_loss) / 2.0


@@ -660,7 +660,10 @@ def forward(

# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+        pooled_output = last_hidden_state[
+            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+        ]

if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -1050,7 +1053,7 @@ def forward(
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
-        logits_per_image = logits_per_text.T
+        logits_per_image = logits_per_text.t()

loss = None
if return_loss:
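Both modeling tweaks keep the traced graph exportable: per the commit message, .T is rewritten as .t() for the ONNX conversion, and the in-line comment explains the cast to int32 before argmax at opset 14. A standalone illustration of the EOT-pooling pattern, with made-up shapes and token ids:

import torch

last_hidden_state = torch.randn(2, 4, 8)        # [batch, seq, width]
input_ids = torch.tensor([[101, 7, 49407, 0],   # the EOT id is the largest
                          [101, 49407, 0, 0]])  # value in each row

# Cast the argmax input to int32 so the exported ArgMax node stays within what
# opset 14 accepts; the returned indices are still int64, so indexing works.
eot_positions = input_ids.to(torch.int).argmax(dim=-1)
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), eot_positions]
assert pooled_output.shape == (2, 8)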
7 changes: 5 additions & 2 deletions src/transformers/models/groupvit/modeling_groupvit.py
@@ -72,7 +72,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
-    image_loss = contrastive_loss(similarity.T)
+    image_loss = contrastive_loss(similarity.t())
return (caption_loss + image_loss) / 2.0


@@ -1132,7 +1132,10 @@ def forward(

# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+        pooled_output = last_hidden_state[
+            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+        ]

if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
2 changes: 1 addition & 1 deletion src/transformers/models/owlvit/modeling_owlvit.py
@@ -71,7 +71,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit
def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
-    image_loss = contrastive_loss(similarity.T)
+    image_loss = contrastive_loss(similarity.t())
return (caption_loss + image_loss) / 2.0


(file header not captured in this view; the hunk below updates another copy of clip_loss)
@@ -154,7 +154,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
# Copied from transformers.models.clip.modeling_clip.clip_loss
def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
caption_loss = contrastive_loss(similarity)
-    image_loss = contrastive_loss(similarity.T)
+    image_loss = contrastive_loss(similarity.t())
return (caption_loss + image_loss) / 2.0


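The same one-line change lands in every model that carries a "# Copied from" copy of clip_loss, propagated by make fix-copies. For a 2-D similarity matrix the rewrite is behavior-preserving, as this standalone check (not part of the diff) illustrates:

import torch

similarity = torch.randn(4, 4)
# Tensor.T and Tensor.t() agree on 2-D inputs; per the commit message, .t() is
# used so the ONNX conversion traces cleanly.
assert torch.equal(similarity.T, similarity.t())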
4 changes: 4 additions & 0 deletions src/transformers/onnx/features.py
@@ -201,6 +201,10 @@ class FeaturesManager:
            "question-answering",
            onnx_config_cls="models.camembert.CamembertOnnxConfig",
        ),
+        "clip": supported_features_mapping(
+            "default",
+            onnx_config_cls="models.clip.CLIPOnnxConfig",
+        ),
"codegen": supported_features_mapping(
"default",
"causal-lm",
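This registry entry is what lets the export tooling resolve the model type string "clip" to its ONNX config class. A hedged sketch of the lookup (the checkpoint name is illustrative):

from transformers import CLIPModel
from transformers.onnx.features import FeaturesManager

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# "default" is the only feature registered for clip in this commit.
model_kind, config_ctor = FeaturesManager.check_supported_model_or_raise(model, feature="default")
onnx_config = config_ctor(model.config)
print(model_kind, type(onnx_config).__name__)  # clip CLIPOnnxConfig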
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
@@ -185,6 +185,7 @@ def test_values_override(self):
    ("big-bird", "google/bigbird-roberta-base"),
    ("ibert", "kssteven/ibert-roberta-base"),
    ("camembert", "camembert-base"),
+    ("clip", "openai/clip-vit-base-patch32"),
("convbert", "YituTech/conv-bert-base"),
("codegen", "Salesforce/codegen-350M-multi"),
("deberta", "microsoft/deberta-base"),
