Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features] Generic clip #286

Merged
merged 54 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
8bea5c8
add large scale test
wanliAlex Jan 24, 2023
2a07546
add fp16 model support
wanliAlex Jan 24, 2023
512e8b7
add fp16 model support
wanliAlex Jan 24, 2023
c351f32
add fp16 model support
wanliAlex Jan 24, 2023
93d5725
add fp16 model support
wanliAlex Jan 24, 2023
0692ed1
add fp16 model support
wanliAlex Jan 24, 2023
08f9336
add fp16 model support
wanliAlex Jan 24, 2023
f6cb5b0
add fp16 model support
wanliAlex Jan 24, 2023
a14c985
add fp16 model support
wanliAlex Jan 24, 2023
992d76e
add fp16 model support
wanliAlex Jan 24, 2023
0c49093
add fp16 model support
wanliAlex Jan 25, 2023
95c0817
add fp16 model support
wanliAlex Jan 25, 2023
83e8d8d
add fp16 model support
wanliAlex Jan 25, 2023
cc17847
generic clip revise
wanliAlex Jan 25, 2023
dab1c22
generic clip revise
wanliAlex Jan 25, 2023
ff4d7b9
generic clip revise
wanliAlex Jan 25, 2023
93cbf80
generic clip revise
wanliAlex Jan 25, 2023
3682c37
generic clip revise
wanliAlex Jan 25, 2023
9a1817c
add generic clip model tests
wanliAlex Jan 25, 2023
3b74215
add generic clip model tests
wanliAlex Jan 25, 2023
8c48b53
open_clip finish
wanliAlex Jan 27, 2023
e20c969
open_clip finish
wanliAlex Jan 27, 2023
c157250
generic clip finished
wanliAlex Jan 27, 2023
d3a1cae
generic clip finished
wanliAlex Jan 27, 2023
5c04f28
add test
wanliAlex Jan 27, 2023
f024861
add test
wanliAlex Jan 27, 2023
e6d47d5
add test
wanliAlex Jan 27, 2023
40914d9
add test
wanliAlex Jan 27, 2023
5cc7537
add test
wanliAlex Jan 27, 2023
3c4bbc6
add test
wanliAlex Jan 27, 2023
5e0197b
add test
wanliAlex Jan 27, 2023
138643b
add test
wanliAlex Jan 30, 2023
28fe006
add test
wanliAlex Jan 30, 2023
d113550
add test
wanliAlex Jan 30, 2023
636d1e3
Separate clip and open_clip load
wanliAlex Jan 30, 2023
e5278b5
Separate clip and open_clip load
wanliAlex Jan 30, 2023
1ae5948
Separate clip and open_clip load
wanliAlex Jan 30, 2023
da9f36f
add *args, **kwargs in sbert model class
wanliAlex Jan 30, 2023
0c4954f
add **kwargs in sbert, onnx sbert model class
wanliAlex Jan 30, 2023
c82c87c
typo fix
wanliAlex Jan 31, 2023
ffe06f1
revise error style!
wanliAlex Feb 1, 2023
5e9314e
remove space
wanliAlex Feb 1, 2023
c469835
change error message
wanliAlex Feb 1, 2023
ec7ca6d
change error message
wanliAlex Feb 1, 2023
7fa5c9d
revised based on pandu's comments
wanliAlex Feb 1, 2023
7326261
adding test pipelines
wanliAlex Feb 2, 2023
f166d65
test another document
wanliAlex Feb 2, 2023
ea35ff3
test another document
wanliAlex Feb 2, 2023
789203f
test another document
wanliAlex Feb 2, 2023
92f7a2f
test another document
wanliAlex Feb 2, 2023
9ea286c
test another document
wanliAlex Feb 2, 2023
e80e53b
test another document
wanliAlex Feb 2, 2023
e7057c7
change downloading path for clip
wanliAlex Feb 2, 2023
42410df
edit error
wanliAlex Feb 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 160 additions & 13 deletions src/marqo/s2_inference/clip_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# from torch import FloatTensor
# from typing import Any, Dict, List, Optional, Union
import os

import PIL.Image
import validators
import requests
import numpy as np
Expand All @@ -10,18 +12,53 @@
import open_clip
from multilingual_clip import pt_multilingual_clip
import transformers

from clip.model import build_model
from marqo.s2_inference.types import *
from marqo.s2_inference.logger import get_logger
import marqo.s2_inference.model_registry as model_registry
from marqo.s2_inference.errors import IncompatibleModelDeviceError, InvalidModelPropertiesError
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from marqo.s2_inference.processing.custom_clip_utils import HFTokenizer, download_pretrained_from_url
from torchvision.transforms import InterpolationMode

logger = get_logger(__name__)

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
BICUBIC = InterpolationMode.BICUBIC


def get_allowed_image_types():
return set(('.jpg', '.png', '.bmp', '.jpeg'))


def _convert_image_to_rgb(image: ImageType) -> ImageType:
# Take a PIL.Image.Image and return its RGB version
return image.convert("RGB")
jn2clark marked this conversation as resolved.
Show resolved Hide resolved


def _get_transform(n_px: int, image_mean:List[float] = None, image_std: List[float] = None) -> torch.Tensor:
'''This function returns a transform to preprocess the image. The processed image will be passed into
clip model for inference.
Args:
n_px: the size of the processed image
image_mean: the mean of the image used for normalization
image_std: the std of the image used for normalization

Returns:
the processed image tensor with shape (3, n_px, n_px)
'''
img_mean = image_mean or OPENAI_DATASET_MEAN
img_std = image_std or OPENAI_DATASET_STD
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize(img_mean, img_std),
])


def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]]) -> List[ImageType]:
"""takes in a list of strings, arrays or urls and either loads and/or converts to PIL
for the clip model
Expand All @@ -44,6 +81,7 @@ def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]]) ->

return results


def load_image_from_path(image_path: str) -> ImageType:
"""Loads an image into PIL from a string path that is either local or a url

Expand All @@ -70,6 +108,7 @@ def load_image_from_path(image_path: str) -> ImageType:

return img


def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType]) -> ImageType:
"""standardizes the input to be a PIL image

Expand All @@ -96,6 +135,7 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType]) -> ImageTy

return img


def _is_image(inputs: Union[str, List[Union[str, ImageType, ndarray]]]) -> bool:
# some logic to determine if something is an image or not
# assume the batch is the same type
Expand Down Expand Up @@ -140,6 +180,7 @@ def _is_image(inputs: Union[str, List[Union[str, ImageType, ndarray]]]) -> bool:
else:
raise UnidentifiedImageError(f"expected type Image or str for inputs but received type {type(thing)}")


class CLIP:

"""
Expand All @@ -156,15 +197,46 @@ def __init__(self, model_type: str = "ViT-B/32", device: str = 'cpu', embedding
self.processor = None
self.embedding_dimension = embedding_dim
self.truncate = truncate
self.model_properties = kwargs.get("model_properties", dict())

def load(self) -> None:

# https://github.com/openai/CLIP/issues/30
self.model, self.preprocess = clip.load(self.model_type, device='cpu', jit=False)
self.model = self.model.to(self.device)
self.tokenizer = clip.tokenize
self.model.eval()

path = self.model_properties.get("localpath", None) or self.model_properties.get("url",None)

if path is None:
# The original method to load the openai clip model
# https://github.com/openai/CLIP/issues/30
self.model, self.preprocess = clip.load(self.model_type, device='cpu', jit=False)
self.model = self.model.to(self.device)
self.tokenizer = clip.tokenize
else:
logger.info("Detecting custom clip model path. We use generic clip model loading.")
if os.path.isfile(path):
self.model_path = path
elif validators.url(path):
self.model_path = download_pretrained_from_url(path)
else:
raise InvalidModelPropertiesError(f"Marqo can not load the custom clip model."
f"The provided model path `{path}` is neither a local file nor a valid url."
f"Please check your provided model url and retry"
f"Check `https://docs.marqo.ai/0.0.12/Models-Reference/dense_retrieval/` for more info.")

self.jit = self.model_properties.get("jit", False)
self.model, self.preprocess = self.custom_clip_load()
self.tokenizer = clip.tokenize

self.model.eval()

pandu-k marked this conversation as resolved.
Show resolved Hide resolved

def custom_clip_load(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to duplicate the load code here from open_clip?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no we don't have to.

self.model_name = self.model_properties.get("name", None)

logger.info(f"The name of the custom clip model is {self.model_name}. We use openai clip load")
model, preprocess = clip.load(name=self.model_path, device="cpu", jit= self.jit)
model = model.to(self.device)
return model, preprocess


def _convert_output(self, output):

if self.device == 'cpu':
Expand All @@ -176,11 +248,12 @@ def _convert_output(self, output):
def normalize(outputs):
return outputs.norm(dim=-1, keepdim=True)


def encode_text(self, sentence: Union[str, List[str]], normalize = True) -> FloatTensor:

if self.model is None:
self.load()

text = self.tokenizer(sentence, truncate=self.truncate).to(self.device)

with torch.no_grad():
Expand Down Expand Up @@ -239,19 +312,91 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
logger.debug('text')
return self.encode_text(inputs, normalize=normalize)


class FP16_CLIP(CLIP):
def __init__(self, model_type: str = "fp16/ViT-B/32", device: str = 'cuda', embedding_dim: int = None,
truncate: bool = True, **kwargs) -> None:
super().__init__(model_type, device, embedding_dim, truncate, **kwargs)
'''This class loads the provided clip model directly from cuda in float16 version. The inference time is halved
with very minor accuracy drop.
'''

pandu-k marked this conversation as resolved.
Show resolved Hide resolved
if not self.device.startswith("cuda"):
raise IncompatibleModelDeviceError(f"Marqo can not load the provided model `{self.model_type}`"
f"FP16 clip model `{self.model_type}` is only available with device `cuda`."
f"Please check you cuda availability or try the fp32 version `{self.model_type.replace('fp16/','')}`"
f"Check `https://docs.marqo.ai/0.0.13/Models-Reference/dense_retrieval/#generic-clip-models` for more info.")

self.model_name = self.model_type.replace("fp16/", "")


def load(self) -> None:
# https://github.com/openai/CLIP/issues/30
self.model, self.preprocess = clip.load(self.model_name, device='cuda', jit=False)
self.model = self.model.to(self.device)
self.tokenizer = clip.tokenize
self.model.eval()

class OPEN_CLIP(CLIP):
def __init__(self, model_type: str = "open_clip/ViT-B-32-quickgelu/laion400m_e32", device: str = 'cpu', embedding_dim: int = None,
truncate: bool = True, **kwargs) -> None:
super().__init__(model_type, device, embedding_dim, truncate , **kwargs)
self.model_name = model_type.split("/", 3)[1]
self.pretrained = model_type.split("/", 3)[2]
self.model_name = model_type.split("/", 3)[1] if model_type.startswith("open_clip/") else model_type
self.pretrained = model_type.split("/", 3)[2] if model_type.startswith("open_clip/") else model_type


def load(self) -> None:
# https://github.com/mlfoundations/open_clip
self.model, _, self.preprocess = open_clip.create_model_and_transforms(self.model_name, pretrained = self.pretrained, device=self.device, jit=False)
self.tokenizer = open_clip.get_tokenizer(self.model_name)
self.model.eval()
path = self.model_properties.get("localpath", None) or self.model_properties.get("url", None)

if path is None:
self.model, _, self.preprocess = open_clip.create_model_and_transforms(self.model_name,
pretrained=self.pretrained,
device=self.device, jit=False)
self.tokenizer = open_clip.get_tokenizer(self.model_name)
self.model.eval()
else:
logger.info("Detecting custom clip model path. We use generic clip model loading.")
if os.path.isfile(path):
self.model_path = path
elif validators.url(path):
self.model_path = download_pretrained_from_url(path)
else:
raise InvalidModelPropertiesError(f"Marqo can not load the custom clip model."
f"The provided model path `{path}` is neither a local file nor a valid url."
f"Please check your provided model url and retry."
f"Check `https://docs.marqo.ai/0.0.13/Models-Reference/dense_retrieval/#generic-clip-models` for more info.")

self.precision = self.model_properties.get("precision", "fp32")
self.jit = self.model_properties.get("jit", False)
self.mean = self.model_properties.get("mean", None)
self.std = self.model_properties.get("std", None)

self.model, self.preprocess = self.custom_clip_load()
self.tokenizer = self.load_tokenizer()

self.model.eval()


def custom_clip_load(self):
self.model_name = self.model_properties.get("name", None)


logger.info(f"The name of the custom clip model is {self.model_name}. We use open_clip load")
model, _, preprocess = open_clip.create_model_and_transforms(model_name=self.model_name, jit = self.jit, pretrained=self.model_path, precision = self.precision,
image_mean=self.mean, image_std=self.std, device = self.device)

return model, preprocess


def load_tokenizer(self):
tokenizer_name = self.model_properties.get("tokenizer", "clip")

if tokenizer_name == "clip":
return clip.tokenize
else:
logger.info(f"Custom HFTokenizer is provided. Loading...")
return HFTokenizer(tokenizer_name)

def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatTensor:

Expand Down Expand Up @@ -302,6 +447,7 @@ def load(self) -> None:
self.textual_model.eval()
self.visual_model.eval()


def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatTensor:

if self.textual_model is None:
Expand All @@ -317,6 +463,7 @@ def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatT

return self._convert_output(outputs)


def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]],
normalize=True) -> FloatTensor:

Expand Down
5 changes: 5 additions & 0 deletions src/marqo/s2_inference/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,9 @@ class RerankerNameError(S2InferenceError):


class ModelNotInCacheError(S2InferenceError):
pass

# Raise an ERROR if the model is only available with "cpu" or "cuda" but
# the other one is provided
class IncompatibleModelDeviceError(S2InferenceError):
pass
3 changes: 1 addition & 2 deletions src/marqo/s2_inference/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from marqo.s2_inference.sbert_utils import Model
from marqo.s2_inference.types import Union, FloatTensor, List

from marqo.s2_inference.logger import get_logger
logger = get_logger(__name__)

Expand Down Expand Up @@ -76,4 +75,4 @@ def mean_pooling(self, model_output, attention_mask):
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def cls_pooling(self, model_output, attention_mask):
return model_output[0][:,0]
return model_output[0][:,0]
31 changes: 30 additions & 1 deletion src/marqo/s2_inference/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from marqo.s2_inference.sbert_onnx_utils import SBERT_ONNX
from marqo.s2_inference.sbert_utils import SBERT, TEST
from marqo.s2_inference.random_utils import Random
from marqo.s2_inference.clip_utils import CLIP, OPEN_CLIP, MULTILINGUAL_CLIP
from marqo.s2_inference.clip_utils import CLIP, OPEN_CLIP, MULTILINGUAL_CLIP, FP16_CLIP
from marqo.s2_inference.types import Any, Dict, List, Optional, Union, FloatTensor
from marqo.s2_inference.onnx_clip_utils import CLIP_ONNX

Expand Down Expand Up @@ -1532,6 +1532,32 @@ def _get_onnx_clip_properties() -> Dict:
}
return ONNX_CLIP_MODEL_PROPERTIES


def _get_fp16_clip_properties() -> Dict:
FP16_CLIP_MODEL_PROPERTIES = {
"fp16/ViT-L/14": {
"name": "fp16/ViT-L/14",
"dimensions": 768,
"type": "fp16_clip",
"notes": "The faster version (fp16, load from `cuda`) of openai clip model"
},
'fp16/ViT-B/32':
{"name": "fp16/ViT-B/32",
"dimensions": 512,
"notes": "The faster version (fp16, load from `cuda`) of openai clip model",
"type": "fp16_clip",
},
'fp16/ViT-B/16':
{"name": "fp16/ViT-B/16",
"dimensions": 512,
"notes": "The faster version (fp16, load from `cuda`) of openai clip model",
"type": "fp16_clip",
},
}

return FP16_CLIP_MODEL_PROPERTIES


def _get_random_properties() -> Dict:
RANDOM_MODEL_PROPERTIES = {
"random":
Expand Down Expand Up @@ -1570,6 +1596,7 @@ def _get_model_load_mappings() -> Dict:
'sbert_onnx':SBERT_ONNX,
'clip_onnx': CLIP_ONNX,
"multilingual_clip" : MULTILINGUAL_CLIP,
"fp16_clip": FP16_CLIP,
'random':Random,
'hf':HF_MODEL}

Expand All @@ -1587,6 +1614,7 @@ def load_model_properties() -> Dict:
open_clip_model_properties = _get_open_clip_properties()
onnx_clip_model_properties = _get_onnx_clip_properties()
multilingual_clip_model_properties = _get_multilingual_clip_properties()
fp16_clip_model_properties = _get_fp16_clip_properties()

# combine the above dicts
model_properties = dict(clip_model_properties.items())
Expand All @@ -1598,6 +1626,7 @@ def load_model_properties() -> Dict:
model_properties.update(open_clip_model_properties)
model_properties.update(onnx_clip_model_properties)
model_properties.update(multilingual_clip_model_properties)
model_properties.update(fp16_clip_model_properties)

all_properties = dict()
all_properties['models'] = model_properties
Expand Down
Loading