[Onnx clip] Adding clip_onnx to our available models for faster inference #245

Merged (25 commits) on Dec 29, 2022
23 changes: 23 additions & 0 deletions src/marqo/s2_inference/model_registry.py
@@ -4,6 +4,7 @@
from marqo.s2_inference.random_utils import Random
from marqo.s2_inference.clip_utils import CLIP, OPEN_CLIP
from marqo.s2_inference.types import Any, Dict, List, Optional, Union, FloatTensor
from marqo.s2_inference.onnx_clip_utils import CLIP_ONNX

# we need to keep track of the embed dim and model load functions/classes
# we can use this as a registry
@@ -511,6 +512,25 @@ def _get_sbert_test_properties() -> Dict:
    }
    return TEST_MODEL_PROPERTIES

def _get_onnx_clip_properties() -> Dict:
    ONNX_CLIP_MODEL_PROPERTIES = {
        "onnx32/openai/ViT-L/14":
            {
                "name": "onnx32/openai/ViT-L/14",
                "dimensions": 768,
                "type": "clip_onnx",
                "note": "the onnx float32 version of openai ViT-L/14"
            },
        "onnx16/openai/ViT-L/14":
            {
                "name": "onnx16/openai/ViT-L/14",
                "dimensions": 768,
                "type": "clip_onnx",
                "note": "the onnx float16 version of openai ViT-L/14"
            },
    }
    return ONNX_CLIP_MODEL_PROPERTIES

def _get_random_properties() -> Dict:
    RANDOM_MODEL_PROPERTIES = {
        "random":
@@ -547,6 +567,7 @@ def _get_model_load_mappings() -> Dict:
            'sbert': SBERT,
            'test': TEST,
            'sbert_onnx': SBERT_ONNX,
            'clip_onnx': CLIP_ONNX,
            'random': Random,
            'hf': HF_MODEL}

@@ -562,6 +583,7 @@ def load_model_properties() -> Dict:
    random_model_properties = _get_random_properties()
    hf_model_properties = _get_hf_properties()
    open_clip_model_properties = _get_open_clip_properties()
    onnx_clip_model_properties = _get_onnx_clip_properties()

    # combine the above dicts
    model_properties = dict(clip_model_properties.items())
@@ -571,6 +593,7 @@ def load_model_properties() -> Dict:
    model_properties.update(random_model_properties)
    model_properties.update(hf_model_properties)
    model_properties.update(open_clip_model_properties)
    model_properties.update(onnx_clip_model_properties)

    all_properties = dict()
    all_properties['models'] = model_properties
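
A minimal sketch of how these registry entries are consumed (assuming get_model_properties_from_registry is exposed from marqo.s2_inference.s2_inference, as the test changes at the bottom of this diff suggest; the exact lookup path is not part of this change):

from marqo.s2_inference.s2_inference import get_model_properties_from_registry

# The registry key is the full model name; "type": "clip_onnx" routes model loading to the
# CLIP_ONNX class registered in _get_model_load_mappings().
props = get_model_properties_from_registry("onnx16/openai/ViT-L/14")
assert props["type"] == "clip_onnx"
assert props["dimensions"] == 768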
177 changes: 177 additions & 0 deletions src/marqo/s2_inference/onnx_clip_utils.py
@@ -0,0 +1,177 @@
# from torch import FloatTensor
# from typing import Any, Dict, List, Optional, Union
import onnx
import os
import validators
import requests
import numpy as np
import clip
import torch
from PIL import Image
import open_clip
from huggingface_hub import hf_hub_download
from marqo.s2_inference.types import *
from marqo.s2_inference.logger import get_logger
import onnxruntime as ort

# Loading shared functions from clip_utils.py. This part should be decoupled from models in the future
from marqo.s2_inference.clip_utils import get_allowed_image_types, format_and_load_CLIP_image, format_and_load_CLIP_images, load_image_from_path,_is_image

logger = get_logger(__name__)

_HF_MODEL_DOWNLOAD = {

    # Please check https://huggingface.co/Marqo for available models.

    "onnx32/openai/ViT-L/14":
        {
            "repo_id": "Marqo/onnx-openai-ViT-L-14",
            "visual_file": "onnx32-openai-ViT-L-14-visual.onnx",
            "textual_file": "onnx32-openai-ViT-L-14-textual.onnx",
            "token": None
        },

    "onnx16/openai/ViT-L/14":
        {
            "repo_id": "Marqo/onnx-openai-ViT-L-14",
            "visual_file": "onnx16-openai-ViT-L-14-visual.onnx",
            "textual_file": "onnx16-openai-ViT-L-14-textual.onnx",
            "token": None
        }
}


class CLIP_ONNX(object):
    """
    Load a pre-exported ONNX version of a CLIP model for faster inference.
    """

    def __init__(self, model_name="onnx32/openai/ViT-L/14", device="cpu", embedding_dim: int = None,
                 truncate: bool = True, load=True, **kwargs):
        self.model_name = model_name
        self.onnx_type, self.source, self.clip_model = self.model_name.split("/", 2)
        self.device = device
        self.truncate = truncate
        self.provider = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.device.startswith("cuda") \
            else ['CPUExecutionProvider']
        self.visual_session = None
        self.textual_session = None
        # The CLIP preprocessor and tokenizer are loaded lazily in load_clip(); initialise them here so
        # the checks in encode() do not raise AttributeError before load() is called.
        self.clip_preprocess = None
        self.tokenizer = None
        self.model_info = _HF_MODEL_DOWNLOAD[self.model_name]

        if self.onnx_type == "onnx16":
            self.visual_type = np.float16
        elif self.onnx_type == "onnx32":
            self.visual_type = np.float32

    def load(self):
        self.load_clip()
        self.load_onnx()

    @staticmethod
    def normalize(outputs):
        # Returns the norm along the last dimension; callers divide their outputs by this value.
        return outputs.norm(dim=-1, keepdim=True)

    def _convert_output(self, output):
        if self.device == 'cpu':
            return output.numpy()
        elif self.device.startswith('cuda'):
            return output.cpu().numpy()

    def load_clip(self):
        if self.source == "openai":
            clip_model, self.clip_preprocess = clip.load(self.clip_model, device="cpu", jit=False)
            self.tokenizer = clip.tokenize
            del clip_model
        elif self.source == "open_clip":
            clip_name, pre_trained = self.clip_model.split("/", 2)
            clip_model, _, self.clip_preprocess = open_clip.create_model_and_transforms(clip_name, pre_trained, device="cpu")
            self.tokenizer = open_clip.get_tokenizer(clip_name)
            del clip_model

    def encode_text(self, sentence, normalize=True):
        text = clip.tokenize(sentence, truncate=self.truncate).cpu()
        text_onnx = text.detach().cpu().numpy().astype(np.int32)

        onnx_input_text = {self.textual_session.get_inputs()[0].name: text_onnx}
        # The onnx output has the shape [1,1,768]; we need to squeeze the dimension
        outputs = torch.squeeze(torch.tensor(np.array(self.textual_session.run(None, onnx_input_text)))).to(torch.float32)

        if normalize:
            _shape_before = outputs.shape
            outputs /= self.normalize(outputs)
            assert outputs.shape == _shape_before
        return self._convert_output(outputs)

    def encode_image(self, images, normalize=True):
        if isinstance(images, list):
            image_input = format_and_load_CLIP_images(images)
        else:
            image_input = [format_and_load_CLIP_image(images)]

        image_input_processed = torch.stack([self.clip_preprocess(_img) for _img in image_input])
        images_onnx = image_input_processed.detach().cpu().numpy().astype(self.visual_type)

        onnx_input_image = {self.visual_session.get_inputs()[0].name: images_onnx}
        # The onnx output has the shape [1,1,768]; we need to squeeze the dimension
        outputs = torch.squeeze(torch.tensor(np.array(self.visual_session.run(None, onnx_input_image)))).to(torch.float32)

        if normalize:
            _shape_before = outputs.shape
            outputs /= self.normalize(outputs)
            assert outputs.shape == _shape_before

        return self._convert_output(outputs)

    def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
               default: str = 'text', normalize=True, **kwargs) -> FloatTensor:

        if self.clip_preprocess is None or self.tokenizer is None:
            self.load_clip()
        if self.visual_session is None or self.textual_session is None:
            self.load_onnx()

        infer = kwargs.pop('infer', True)

        if infer and _is_image(inputs):
            is_image = True
        else:
            is_image = False
            if default == 'text':
                is_image = False
            elif default == 'image':
                is_image = True
            else:
                raise ValueError(f"expected default='image' or default='text' but received {default}")

        if is_image:
            logger.debug('image')
            return self.encode_image(inputs, normalize=normalize)
        else:
            logger.debug('text')
            return self.encode_text(inputs, normalize=normalize)

    def load_onnx(self):
        self.visual_file = self.download_model(self.model_info["repo_id"], self.model_info["visual_file"])
        self.textual_file = self.download_model(self.model_info["repo_id"], self.model_info["textual_file"])
        self.visual_session = ort.InferenceSession(self.visual_file, providers=self.provider)
        self.textual_session = ort.InferenceSession(self.textual_file, providers=self.provider)

    @staticmethod
    def download_model(repo_id: str, filename: str, cache_folder: str = None) -> str:
        file_path = hf_hub_download(repo_id=repo_id, filename=filename,
                                    cache_dir=cache_folder)
        return file_path
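
A rough usage sketch of CLIP_ONNX on its own, outside Marqo's model cache (note: load() downloads the large ONNX weight files from the Marqo Hugging Face repo on first use, and the image path below is a hypothetical placeholder):

from marqo.s2_inference.onnx_clip_utils import CLIP_ONNX

model = CLIP_ONNX(model_name="onnx16/openai/ViT-L/14", device="cpu")
model.load()  # loads the CLIP preprocess/tokenizer and opens the two onnxruntime sessions

text_emb = model.encode("a photo of a cat", default="text", normalize=True)
print(text_emb.shape)  # (768,) for a single sentence, matching the registry's "dimensions"

image_emb = model.encode("path/to/local_image.jpg", default="image", normalize=True)
print(image_emb.shape)  # (768,)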
4 changes: 3 additions & 1 deletion src/marqo/s2_inference/s2_inference.py
@@ -8,6 +8,7 @@
from marqo.s2_inference.configs import get_default_device, get_default_normalization, get_default_seq_length
from marqo.s2_inference.types import *
from marqo.s2_inference.logger import get_logger
from timeit import default_timer as timer

logger = get_logger(__name__)

@@ -35,17 +36,18 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties:
    Raises:
        VectoriseError: if the content can't be vectorised, for some reason.
    """

    validated_model_properties = _validate_model_properties(model_name, model_properties)
    model_cache_key = _create_model_cache_key(model_name, device, validated_model_properties)

    _update_available_models(model_cache_key, model_name, validated_model_properties, device, normalize_embeddings)

    try:
        vectorised = available_models[model_cache_key].encode(content, normalize=normalize_embeddings, **kwargs)
    except UnidentifiedImageError as e:
        raise VectoriseError from e

    return _convert_vectorized_output(vectorised)


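Sketch of the same flow through the top-level vectorise() entry point (argument order mirrors test_onnx_clip_vectorise at the bottom of this diff; the output container is whatever _convert_vectorized_output produces, expected to be a list of embedding lists):

from marqo.s2_inference.s2_inference import vectorise, get_model_properties_from_registry

model_properties = get_model_properties_from_registry("onnx16/openai/ViT-L/14")
# vectorise loads (and caches) the model for the given device, encodes, then converts the output.
output_v = vectorise("onnx16/openai/ViT-L/14", "hello", model_properties, "cpu", normalize_embeddings=True)
print(len(output_v), len(output_v[0]))  # expected: 1, 768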
29 changes: 27 additions & 2 deletions tests/s2_inference/test_encoding.py
@@ -106,7 +106,8 @@ def test_compare_onnx_sbert_text_models(self):
                assert abs(model_onnx.encode(sentence) - model_sbert.encode(sentence)).sum() < eps

    def test_model_outputs(self):
        names = ['open_clip/ViT-B-32/laion400m_e32', "all-MiniLM-L6-v1",
        names = ["onnx32/openai/ViT-L/14", "onnx16/openai/ViT-L/14",
                 'open_clip/ViT-B-32/laion400m_e32', "all-MiniLM-L6-v1",
                 "all_datasets_v4_MiniLM-L6", "hf/all-MiniLM-L6-v1",
                 "hf/all_datasets_v4_MiniLM-L6", "onnx/all-MiniLM-L6-v1", "onnx/all_datasets_v4_MiniLM-L6"]
        sentences = ['hello', 'this is a test sentence. so is this.', ['hello', 'this is a test sentence. so is this.']]
@@ -121,7 +122,8 @@ def test_model_outputs(self):
                assert _check_output_type(_convert_vectorized_output(output))

    def test_model_normalization(self):
        names = ['open_clip/ViT-B-32/laion400m_e32', 'RN50', "ViT-B/16", "all-MiniLM-L6-v1",
        names = ["onnx32/openai/ViT-L/14", "onnx16/openai/ViT-L/14",
                 'open_clip/ViT-B-32/laion400m_e32', 'RN50', "ViT-B/16", "all-MiniLM-L6-v1",
                 "all_datasets_v4_MiniLM-L6", "hf/all-MiniLM-L6-v1", "hf/all_datasets_v4_MiniLM-L6",
                 "onnx/all-MiniLM-L6-v1", "onnx/all_datasets_v4_MiniLM-L6"]
        sentences = ['hello', 'this is a test sentence. so is this.', ['hello', 'this is a test sentence. so is this.']]
@@ -203,3 +205,26 @@ def test_open_clip_embedding_size(self):
            output_dimension = len(output_v[0])

            assert registered_dimension == output_dimension

    def test_onnx_clip_vectorise(self):

        names = ["onnx32/openai/ViT-L/14", "onnx16/openai/ViT-L/14"]

        sentences = ['hello', 'this is a test sentence. so is this.',
                     ['hello', 'this is a test sentence. so is this.']]
        device = 'cpu'
        eps = 1e-9

        for name in names:
            model_properties = get_model_properties_from_registry(name)
            model = _load_model(model_properties['name'], model_properties=model_properties, device=device)

            for sentence in sentences:
                output_v = vectorise(name, sentence, model_properties, device, normalize_embeddings=True)

                assert _check_output_type(output_v)

                output_m = model.encode(sentence, normalize=True)

                assert abs(torch.FloatTensor(output_m) - torch.FloatTensor(output_v)).sum() < eps