Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix image download headers regression and fix png image issue #1022

Merged
merged 31 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
deaecf1
Finish add documents
wanliAlex Oct 23, 2024
8e1bf32
Finish search
wanliAlex Oct 23, 2024
574e61d
Finish development
wanliAlex Oct 23, 2024
dfbe81c
Fix more than 2 modalities bugs in search
wanliAlex Oct 24, 2024
3cc7a2f
Need to fix infer issue
wanliAlex Oct 24, 2024
aa2b1d6
Revert changes in src/marqo/s2_inference/ and reconsider parameters p…
wanliAlex Oct 24, 2024
2dbdb00
Fix tests
wanliAlex Oct 24, 2024
6e4b924
Fix hybrid
wanliAlex Oct 24, 2024
0f84ba6
Fix hybrid tests
wanliAlex Oct 24, 2024
8afca5f
Fix embed
wanliAlex Oct 24, 2024
414df74
Fix embed
wanliAlex Oct 24, 2024
b5e2195
Add add_documents tests and search tests
wanliAlex Oct 24, 2024
9b2a08a
Respond to Farshid's comments
wanliAlex Oct 24, 2024
31048f0
Replace all the image_download_headers with media_download_headers
wanliAlex Oct 24, 2024
2e35bac
Fix tests
wanliAlex Oct 24, 2024
42997d8
Catch mainline
wanliAlex Oct 24, 2024
2cc4622
Add language bind modality tests
wanliAlex Oct 24, 2024
bc35efb
Fix tests
wanliAlex Oct 24, 2024
9cb1edd
Catch mainline
wanliAlex Oct 24, 2024
674bb17
Fix headers for media
wanliAlex Oct 24, 2024
27da146
Fix headers for media
wanliAlex Oct 24, 2024
c37fe9e
Fix headers for media
wanliAlex Oct 24, 2024
a35b1c9
Fix media download headers for video and audio
wanliAlex Oct 25, 2024
532cc4e
Fix tests
wanliAlex Oct 25, 2024
f85ee8b
Convert image to RGB for languagebind
wanliAlex Oct 25, 2024
bace6fd
Fix tests
wanliAlex Oct 25, 2024
72e4b36
Delete a test
wanliAlex Oct 25, 2024
96e2231
Add back the test
wanliAlex Oct 25, 2024
2353e6a
Change largemodel tests logic
wanliAlex Oct 25, 2024
573b469
Fix tests
wanliAlex Oct 25, 2024
a5ea947
Fix tests
wanliAlex Oct 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/marqo/api/models/add_docs_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Config:
tensorFields: Optional[List] = None
useExistingTensors: bool = False
imageDownloadHeaders: dict = Field(default_factory=dict)
mediaDownloadHeaders: Optional[dict] = None
modelAuth: Optional[ModelAuth] = None
mappings: Optional[dict] = None
documents: Union[Sequence[Union[dict, Any]], np.ndarray]
Expand All @@ -38,3 +39,22 @@ def validate_thread_counts(cls, values):
if media_count is not None and image_count != read_env_vars_and_defaults_ints(EnvVars.MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST):
raise ValueError("Cannot set both imageDownloadThreadCount and mediaDownloadThreadCount")
return values

@root_validator(skip_on_failure=True)
def _validate_image_download_headers_and_media_download_headers(cls, values):
    """Reconcile the deprecated imageDownloadHeaders with mediaDownloadHeaders.

    imageDownloadHeaders is deprecated and will be removed in the future. When it
    is the only one supplied, its value is copied into mediaDownloadHeaders so the
    rest of the code only has to read mediaDownloadHeaders.

    Raises:
        ValueError: if both imageDownloadHeaders and mediaDownloadHeaders are set
            to non-empty values.
    """
    legacy_headers = values.get('imageDownloadHeaders')

    # Nothing to migrate when the deprecated field is absent or empty.
    if not legacy_headers:
        return values

    if values.get('mediaDownloadHeaders'):
        raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
                         "'imageDownloadHeaders' is deprecated and will be removed in the future. "
                         "Use mediaDownloadHeaders instead.")

    values['mediaDownloadHeaders'] = legacy_headers
    return values
32 changes: 27 additions & 5 deletions src/marqo/api/models/embed_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@
import pydantic
from typing import Union, List, Dict, Optional, Any

from pydantic import Field, root_validator

from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.models.api_models import BaseMarqoModel
from marqo.base_model import MarqoBaseModel
from marqo.core.embed.embed import EmbedContentType



class EmbedRequest(BaseMarqoModel):
class EmbedRequest(MarqoBaseModel):
# content can be a single query or list of queries. Queries can be a string or a dictionary.
content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]]
image_download_headers: Optional[Dict] = None
image_download_headers: Optional[Dict] = Field(default=None, alias="imageDownloadHeaders")
mediaDownloadHeaders: Optional[Dict] = None
modelAuth: Optional[ModelAuth] = None
content_type: Optional[EmbedContentType] = EmbedContentType.Query
content_type: Optional[EmbedContentType] = Field(default=EmbedContentType.Query, alias="contentType")

@pydantic.validator('content')
def validate_content(cls, value):
Expand Down Expand Up @@ -47,4 +50,23 @@ def validate_content(cls, value):
else:
raise ValueError("Embed content should be a string, a dictionary, or a list of strings or dictionaries")

return value
return value

@root_validator(skip_on_failure=True)
def _validate_image_download_headers_and_media_download_headers(cls, values):
    """Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.

    If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
    rest of the code.

    imageDownloadHeaders is deprecated and will be removed in the future.

    Raises:
        ValueError: if both imageDownloadHeaders and mediaDownloadHeaders are set
            to non-empty values.
    """
    # A post root validator receives ``values`` keyed by FIELD NAME, not by alias.
    # The deprecated field is declared as ``image_download_headers`` (alias
    # "imageDownloadHeaders"), so it must be looked up by its field name; looking
    # it up by the alias always yields None and silently skips this validation.
    image_download_headers = values.get('image_download_headers')
    media_download_headers = values.get('mediaDownloadHeaders')
    if image_download_headers and media_download_headers:
        raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
                         "'imageDownloadHeaders' is deprecated and will be removed in the future. "
                         "Use mediaDownloadHeaders instead.")
    if image_download_headers:
        values['mediaDownloadHeaders'] = image_download_headers
    return values
13 changes: 7 additions & 6 deletions src/marqo/core/embed/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def validate_default_device(cls, value):
return value

def embed_content(
self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
index_name: str, device: str = None, image_download_headers: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None,
content_type: Optional[EmbedContentType] = EmbedContentType.Query
) -> Dict:
self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
index_name: str, device: str = None,
media_download_headers: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None,
content_type: Optional[EmbedContentType] = EmbedContentType.Query
) -> Dict:
"""
Use the index's model to embed the content

Expand Down Expand Up @@ -105,7 +106,7 @@ def embed_content(
BulkSearchQueryEntity(
q=content_entry,
index=marqo_index,
image_download_headers=image_download_headers,
mediaDownloadHeaders=media_download_headers,
modelAuth=model_auth,
text_query_prefix=prefix
# TODO: Check if it's fine that we leave out the other parameters
Expand Down
24 changes: 12 additions & 12 deletions src/marqo/core/inference/embedding_models/abstract_clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.core.inference.embedding_models.abstract_embedding_model import AbstractEmbeddingModel
from marqo.core.inference.embedding_models.image_download import (_is_image, format_and_load_CLIP_images,
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.s2_inference.logger import get_logger
from marqo.s2_inference.types import *
Expand Down Expand Up @@ -50,11 +50,11 @@ def encode_text(self, inputs: Union[str, List[str]], normalize: bool = True) ->
pass

@abstractmethod
def encode_image(self, inputs, normalize: bool = True, image_download_headers: dict = None) -> np.ndarray:
def encode_image(self, inputs, normalize: bool = True, media_download_headers: dict = None) -> np.ndarray:
pass

def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
default: str = 'text', normalize=True, **kwargs) -> np.ndarray:
def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], normalize=True, **kwargs) -> np.ndarray:
default = "text"
infer = kwargs.pop('infer', True)
if infer and _is_image(inputs):
is_image = True
Expand All @@ -68,8 +68,8 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],

if is_image:
logger.debug('image')
image_download_headers = kwargs.get("image_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
media_download_headers = kwargs.get("media_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, media_download_headers=media_download_headers)
else:
logger.debug('text')
return self.encode_text(inputs, normalize=normalize)
Expand All @@ -85,27 +85,27 @@ def normalize(outputs):
return outputs.norm(dim=-1, keepdim=True)

def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
image_download_headers: Optional[Dict] = None) -> Tensor:
media_download_headers: Optional[Dict] = None) -> Tensor:
"""Preprocess the input image to be ready for the model.

Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
image_download_headers (Optional[Dict]): headers for the image download
media_download_headers (Optional[Dict]): headers for the image download
Return:
Tensor: the processed image tensor with shape (batch_size, channel, n_px, n_px)
"""
if self.model is None:
self.load()
if image_download_headers is None:
image_download_headers = dict()
if media_download_headers is None:
media_download_headers = dict()

# default to batch encoding
if isinstance(images, list):
image_input: List[Union[ImageType, Tensor]] \
= format_and_load_CLIP_images(images, image_download_headers)
= format_and_load_CLIP_images(images, media_download_headers)
else:
image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, image_download_headers)]
image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, media_download_headers)]

image_input_processed: Tensor = torch.stack([self.preprocess(_img).to(self.device) \
if not isinstance(_img, torch.Tensor) else _img \
Expand Down
Loading
Loading