Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix image download headers regresison and fix png image issue #1022

Merged
merged 31 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
deaecf1
Finish add documents
wanliAlex Oct 23, 2024
8e1bf32
Finish search
wanliAlex Oct 23, 2024
574e61d
Finish development
wanliAlex Oct 23, 2024
dfbe81c
Fix more than 2 modalities bugs in search
wanliAlex Oct 24, 2024
3cc7a2f
Need to fix infer issue
wanliAlex Oct 24, 2024
aa2b1d6
Revert changes in src/marqo/s2_inference/ and reconsider parameters p…
wanliAlex Oct 24, 2024
2dbdb00
Fix tests
wanliAlex Oct 24, 2024
6e4b924
Fix hybrid
wanliAlex Oct 24, 2024
0f84ba6
Fix hybrid tests
wanliAlex Oct 24, 2024
8afca5f
Fix embed
wanliAlex Oct 24, 2024
414df74
Fix embed
wanliAlex Oct 24, 2024
b5e2195
Add add_documents tests and search tests
wanliAlex Oct 24, 2024
9b2a08a
Respond to Farshid's comments
wanliAlex Oct 24, 2024
31048f0
Replace all the image_download_headers with media_download_headers
wanliAlex Oct 24, 2024
2e35bac
Fix tests
wanliAlex Oct 24, 2024
42997d8
Catch mainline
wanliAlex Oct 24, 2024
2cc4622
Add language bind modality tests
wanliAlex Oct 24, 2024
bc35efb
Fix tests
wanliAlex Oct 24, 2024
9cb1edd
Catch mainline
wanliAlex Oct 24, 2024
674bb17
Fix headers for media
wanliAlex Oct 24, 2024
27da146
Fix headers for media
wanliAlex Oct 24, 2024
c37fe9e
Fix headers for media
wanliAlex Oct 24, 2024
a35b1c9
Fix media download headers for video and audio
wanliAlex Oct 25, 2024
532cc4e
Fix tests
wanliAlex Oct 25, 2024
f85ee8b
Convert image to RGB for languagebind
wanliAlex Oct 25, 2024
bace6fd
Fix tests
wanliAlex Oct 25, 2024
72e4b36
Delete a test
wanliAlex Oct 25, 2024
96e2231
Add back the test
wanliAlex Oct 25, 2024
2353e6a
Change largemodel tests logic
wanliAlex Oct 25, 2024
573b469
Fix tests
wanliAlex Oct 25, 2024
a5ea947
Fix tests
wanliAlex Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/marqo/core/inference/embedding_models/abstract_clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.core.inference.embedding_models.abstract_embedding_model import AbstractEmbeddingModel
from marqo.core.inference.embedding_models.image_download import (_is_image, format_and_load_CLIP_images,
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.s2_inference.logger import get_logger
from marqo.s2_inference.types import *
Expand Down Expand Up @@ -50,7 +50,7 @@ def encode_text(self, inputs: Union[str, List[str]], normalize: bool = True) ->
pass

@abstractmethod
def encode_image(self, inputs, normalize: bool = True, image_download_headers: dict = None) -> np.ndarray:
def encode_image(self, inputs, normalize: bool = True, media_download_headers: dict = None) -> np.ndarray:
pass

def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], normalize=True, **kwargs) -> np.ndarray:
Expand All @@ -68,8 +68,8 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], nor

if is_image:
logger.debug('image')
image_download_headers = kwargs.get("media_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
media_download_headers = kwargs.get("media_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, media_download_headers=media_download_headers)
else:
logger.debug('text')
return self.encode_text(inputs, normalize=normalize)
Expand All @@ -85,27 +85,27 @@ def normalize(outputs):
return outputs.norm(dim=-1, keepdim=True)

def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
image_download_headers: Optional[Dict] = None) -> Tensor:
media_download_headers: Optional[Dict] = None) -> Tensor:
"""Preprocess the input image to be ready for the model.

Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
image_download_headers (Optional[Dict]): headers for the image download
media_download_headers (Optional[Dict]): headers for the image download
Return:
Tensor: the processed image tensor with shape (batch_size, channel, n_px, n_px)
"""
if self.model is None:
self.load()
if image_download_headers is None:
image_download_headers = dict()
if media_download_headers is None:
media_download_headers = dict()

# default to batch encoding
if isinstance(images, list):
image_input: List[Union[ImageType, Tensor]] \
= format_and_load_CLIP_images(images, image_download_headers)
= format_and_load_CLIP_images(images, media_download_headers)
else:
image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, image_download_headers)]
image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, media_download_headers)]

image_input_processed: Tensor = torch.stack([self.preprocess(_img).to(self.device) \
if not isinstance(_img, torch.Tensor) else _img \
Expand Down
236 changes: 0 additions & 236 deletions src/marqo/core/inference/embedding_models/image_download.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/marqo/core/inference/embedding_models/open_clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ def _download_from_repo(self):
return model_file_path

def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]],
image_download_headers: Optional[Dict] = None,
media_download_headers: Optional[Dict] = None,
normalize=True) -> FloatTensor:

self.image_input_processed: Tensor = self._preprocess_images(images, image_download_headers)
self.image_input_processed: Tensor = self._preprocess_images(images, media_download_headers)

with torch.no_grad():
if self.device.startswith("cuda"):
Expand Down
20 changes: 10 additions & 10 deletions src/marqo/core/inference/image_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _is_image(inputs: Union[str, List[Union[str, ImageType, ndarray]]]) -> bool:
raise UnidentifiedImageError(f"expected type Image or str for inputs but received type {type(thing)}")


def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], image_download_headers: dict) -> List[
def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], media_download_headers: dict) -> List[
ImageType]:
"""takes in a list of strings, arrays or urls and either loads and/or converts to PIL
for the clip model
Expand All @@ -90,13 +90,13 @@ def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], im

results = []
for image in images:
results.append(format_and_load_CLIP_image(image, image_download_headers))
results.append(format_and_load_CLIP_image(image, media_download_headers))

return results


def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
image_download_headers: dict) -> Union[ImageType, Tensor]:
media_download_headers: dict) -> Union[ImageType, Tensor]:
"""standardizes the input to be a PIL image

Args:
Expand All @@ -113,7 +113,7 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
"""
# check for the input type
if isinstance(image, str):
img = load_image_from_path(image, image_download_headers)
img = load_image_from_path(image, media_download_headers)
elif isinstance(image, np.ndarray):
img = Image.fromarray(image.astype('uint8'), 'RGB')
elif isinstance(image, torch.Tensor):
Expand All @@ -127,13 +127,13 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
return img


def load_image_from_path(image_path: str, image_download_headers: dict, timeout_ms=3000,
def load_image_from_path(image_path: str, media_download_headers: dict, timeout_ms=3000,
metrics_obj: Optional[RequestMetrics] = None) -> ImageType:
"""Loads an image into PIL from a string path that is either local or a url

Args:
image_path (str): Local or remote path to image.
image_download_headers (dict): header for the image download
media_download_headers (dict): header for the image download
timeout_ms (int): timeout (in milliseconds), for the whole request
Raises:
ValueError: If the local path is invalid, and is not a url
Expand All @@ -148,7 +148,7 @@ def load_image_from_path(image_path: str, image_download_headers: dict, timeout_
if metrics_obj is not None:
metrics_obj.start(f"image_download.{image_path}")
try:
img_io: BytesIO = download_image_from_url(image_path, image_download_headers, timeout_ms)
img_io: BytesIO = download_image_from_url(image_path, media_download_headers, timeout_ms)
img = Image.open(img_io)
except ImageDownloadError as e:
raise UnidentifiedImageError(str(e)) from e
Expand All @@ -167,12 +167,12 @@ def load_image_from_path(image_path: str, image_download_headers: dict, timeout_
return img


def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
def download_image_from_url(image_path: str, media_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
"""Download an image from a URL and return a PIL image using pycurl.

Args:
image_path (str): URL to the image.
image_download_headers (dict): Headers for the image download.
media_download_headers (dict): Headers for the image download.
timeout_ms (int): Timeout in milliseconds, for the whole request.

Returns:
Expand All @@ -199,7 +199,7 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
c.setopt(pycurl.FOLLOWLOCATION, 1)

headers = DEFAULT_HEADERS.copy()
headers.update(image_download_headers)
headers.update(media_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])

try:
Expand Down
Loading
Loading