Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds CLIP to models exportable with ONNX #18515

Merged
merged 21 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Ready-made configurations include the following architectures:
- BlenderbotSmall
- BLOOM
- CamemBERT
- CLIP
- CodeGen
- ConvBERT
- ConvNeXT
Expand Down
16 changes: 14 additions & 2 deletions src/transformers/models/clip/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@


_import_structure = {
"configuration_clip": ["CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP", "CLIPConfig", "CLIPTextConfig", "CLIPVisionConfig"],
"configuration_clip": [
"CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
"CLIPConfig",
"CLIPOnnxConfig",
"CLIPTextConfig",
"CLIPVisionConfig",
],
"tokenization_clip": ["CLIPTokenizer"],
}

Expand Down Expand Up @@ -95,7 +101,13 @@


if TYPE_CHECKING:
from .configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from .configuration_clip import (
CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
CLIPConfig,
CLIPOnnxConfig,
CLIPTextConfig,
CLIPVisionConfig,
)
from .tokenization_clip import CLIPTokenizer

try:
Expand Down
50 changes: 49 additions & 1 deletion src/transformers/models/clip/configuration_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@

import copy
import os
from typing import Union
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union


if TYPE_CHECKING:
from ...processing_utils import ProcessorMixin
from ...utils import TensorType

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


Expand Down Expand Up @@ -317,3 +324,44 @@ def to_dict(self):
output["vision_config"] = self.vision_config.to_dict()
output["model_type"] = self.__class__.model_type
return output


class CLIPOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("pixel_values", {0: "batch"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)

@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("logits_per_image", {0: "batch"}),
("logits_per_text", {0: "batch"}),
("text_embeds", {0: "batch"}),
("image_embeds", {0: "batch"}),
]
)

@property
def atol_for_validation(self) -> float:
return 1e-4

def generate_dummy_inputs(
self,
processor: "ProcessorMixin",
framework: Optional["TensorType"] = None,
) -> Mapping[str, Any]:

text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
return {**text_input_dict, **image_input_dict}

@property
def default_onnx_opset(self) -> int:
return 14
4 changes: 3 additions & 1 deletion src/transformers/models/clip/modeling_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,9 @@ def forward(

# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a comment here to say that the cast to(torch.int) is required for ONNX export.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, added a comment and pushed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually ONNX does support int64. But from what I read here, ArgMax does not support int64 inputs with opset 14. So I would just put something like that:

# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, my bad. updated the comment and pushed

]

if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ class FeaturesManager:
"question-answering",
onnx_config_cls="models.camembert.CamembertOnnxConfig",
),
"clip": supported_features_mapping(
"default",
onnx_config_cls="models.clip.CLIPOnnxConfig",
),
"codegen": supported_features_mapping(
"default",
"causal-lm",
Expand Down
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def test_values_override(self):
("big-bird", "google/bigbird-roberta-base"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("clip", "openai/clip-vit-base-patch32"),
("convbert", "YituTech/conv-bert-base"),
("codegen", "Salesforce/codegen-350M-multi"),
("deberta", "microsoft/deberta-base"),
Expand Down