From deaecf1b0ee14e3663e7acad046e8428b7b94aae Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Wed, 23 Oct 2024 21:43:41 +1100
Subject: [PATCH 01/29] Finish add documents
---
src/marqo/api/models/add_docs_objects.py | 20 ++++
src/marqo/core/models/add_docs_params.py | 4 +-
.../s2_inference/multimodal_model_load.py | 12 +-
src/marqo/tensor_search/add_docs.py | 108 ++++++++++--------
src/marqo/tensor_search/web/api_utils.py | 16 +--
5 files changed, 100 insertions(+), 60 deletions(-)
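At the API surface, this patch means add_documents callers pass authentication headers as mediaDownloadHeaders instead of the deprecated imageDownloadHeaders. A minimal sketch of a request body, using only fields declared in the model below (values are illustrative):

    body = {
        "documents": [{"_id": "doc1", "image_field": "https://example.com/cat.png"}],
        "tensorFields": ["image_field"],
        "mediaDownloadHeaders": {"Authorization": "Bearer <token>"},  # replaces imageDownloadHeaders
    }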
diff --git a/src/marqo/api/models/add_docs_objects.py b/src/marqo/api/models/add_docs_objects.py
index 5d7a33695..2174753e6 100644
--- a/src/marqo/api/models/add_docs_objects.py
+++ b/src/marqo/api/models/add_docs_objects.py
@@ -22,6 +22,7 @@ class Config:
tensorFields: Optional[List] = None
useExistingTensors: bool = False
imageDownloadHeaders: dict = Field(default_factory=dict)
+ mediaDownloadHeaders: Optional[dict] = None
modelAuth: Optional[ModelAuth] = None
mappings: Optional[dict] = None
documents: Union[Sequence[Union[dict, Any]], np.ndarray]
@@ -38,3 +39,22 @@ def validate_thread_counts(cls, values):
if media_count is not None and image_count != read_env_vars_and_defaults_ints(EnvVars.MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST):
raise ValueError("Cannot set both imageDownloadThreadCount and mediaDownloadThreadCount")
return values
+
+ @root_validator(skip_on_failure=True)
+ def _validate_image_download_headers_and_media_download_headers(cls, values):
+ """Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.
+
+ If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
+ rest of the code.
+
+ imageDownloadHeaders is deprecated and will be removed in the future.
+ """
+ image_download_headers = values.get('imageDownloadHeaders')
+ media_download_headers = values.get('mediaDownloadHeaders')
+ if image_download_headers and media_download_headers:
+            raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
+                             "imageDownloadHeaders is deprecated and will be removed in a future release. "
+                             "Use mediaDownloadHeaders instead.")
+ if image_download_headers:
+ values['mediaDownloadHeaders'] = image_download_headers
+ return values
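Assuming the body model in this file is AddDocsBodyParams (the name used by add_docs_params_orchestrator later in this patch), the validator above behaves roughly as in this sketch; the header value is illustrative, and supplying both spellings is rejected:

    from marqo.api.models.add_docs_objects import AddDocsBodyParams

    params = AddDocsBodyParams(
        documents=[{"_id": "1", "img": "https://example.com/a.png"}],
        imageDownloadHeaders={"Authorization": "Bearer <token>"},  # deprecated spelling
    )
    assert params.mediaDownloadHeaders == {"Authorization": "Bearer <token>"}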
diff --git a/src/marqo/core/models/add_docs_params.py b/src/marqo/core/models/add_docs_params.py
index 557bf166b..66cf12185 100644
--- a/src/marqo/core/models/add_docs_params.py
+++ b/src/marqo/core/models/add_docs_params.py
@@ -31,7 +31,7 @@ class AddDocsParams(BaseModel):
device: Device used to carry out the document update, if `None` is given, it will be determined by
EnvVars.MARQO_BEST_AVAILABLE_DEVICE
image_download_thread_count: number of threads used to concurrently download images
- image_download_headers: headers to authenticate image download
+ media_download_headers: headers to authenticate media download requests
mappings: a dictionary used to handle all the object field content in the doc,
e.g., multimodal_combination field
model_auth: an object used to authorise downloading an object from a datastore
@@ -53,7 +53,7 @@ class Config:
image_download_thread_count: int = Field(default_factory=lambda: read_env_vars_and_defaults_ints(
EnvVars.MARQO_IMAGE_DOWNLOAD_THREAD_COUNT_PER_REQUEST))
media_download_thread_count: Optional[int]
- image_download_headers: dict = Field(default_factory=dict)
+ media_download_headers: Optional[dict] = None
use_existing_tensors: bool = False
mappings: Optional[dict] = None
model_auth: Optional[ModelAuth] = None
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index 173630c22..0c7ed9431 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -11,10 +11,11 @@
from pydantic import BaseModel
from enum import Enum
from abc import ABC, abstractmethod
-from typing import List, Dict, Any, Union
+from typing import List, Dict, Any, Union, Optional
from PIL.Image import Image
import torch
from urllib.parse import quote
+from marqo.core.inference.image_download import DEFAULT_HEADERS
from marqo.s2_inference.multimodal_model_load import *
@@ -130,8 +131,9 @@ def encode(self, content, modality, **kwargs):
@contextmanager
-def fetch_content_sample(url, sample_size=10240): # 10 KB
- response = requests.get(url, stream=True)
+def fetch_content_sample(url, media_download_headers: Optional[dict] = None, sample_size=10240): # 10 KB
+    # Passing headers=None to requests.get() is fine; it leaves the default headers unchanged
+ response = requests.get(url, stream=True, headers=media_download_headers)
buffer = io.BytesIO()
try:
for chunk in response.iter_content(chunk_size=min(sample_size, 8192)):
@@ -145,7 +147,7 @@ def fetch_content_sample(url, sample_size=10240): # 10 KB
response.close()
-def infer_modality(content: Union[str, List[str], bytes]) -> Modality:
+def infer_modality(content: Union[str, List[str], bytes], media_download_headers: Optional[dict] = None) -> Modality:
"""
Infer the modality of the content. Video, audio, image or text.
"""
@@ -167,7 +169,7 @@ def infer_modality(content: Union[str, List[str], bytes]) -> Modality:
if validate_url(encoded_url):
# Use context manager to handle content sample
try:
- with fetch_content_sample(encoded_url) as sample:
+ with fetch_content_sample(encoded_url, media_download_headers) as sample:
mime = magic.from_buffer(sample.read(), mime=True)
if mime.startswith('image/'):
return Modality.IMAGE
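With the change above, modality sniffing can reach URLs that require authentication, because the same headers used for the eventual download are forwarded to the sample fetch. A minimal usage sketch (header value illustrative):

    from marqo.s2_inference.multimodal_model_load import infer_modality

    headers = {"Authorization": "Bearer <token>"}
    modality = infer_modality("https://example.com/clip.mp4", media_download_headers=headers)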
diff --git a/src/marqo/tensor_search/add_docs.py b/src/marqo/tensor_search/add_docs.py
index a8dbded3d..c87a7c78f 100644
--- a/src/marqo/tensor_search/add_docs.py
+++ b/src/marqo/tensor_search/add_docs.py
@@ -42,7 +42,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
image_download_headers: dict,
device: str = None,
media_field_types_mapping: Optional[Dict[str, FieldType]] = None,
- download_headers: Optional[Dict] = None, # Optional for now
+ media_download_headers: Optional[Dict] = None,
metric_obj: Optional[RequestMetrics] = None,
preprocessors: Optional[Dict[str, Compose]] = None,
marqo_index_type: Optional[IndexType] = None,
@@ -59,7 +59,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
image_repo: dictionary that will be mutated by this thread. It will add PIL images
as values and the URLs as keys
tensor_fields: A tuple of tensor_fields. Images will be downloaded for these fields only.
- image_download_headers: A dict of headers for image download. Can be used
+        media_download_headers: A dict of headers for media download. Can be used
to authenticate image downloads
force_download: If True, skip the _is_image check and download the fields as images.
Side Effects:
@@ -93,7 +93,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
if isinstance(doc[field], str) or force_download:
try:
- inferred_modality = infer_modality(doc[field])
+ inferred_modality = infer_modality(doc[field], media_download_headers)
except MediaDownloadError as e:
if is_structured_index and media_field_types_mapping[field] == FieldType.ImagePointer:
# Continue processing for structured indexes with image fields
@@ -222,24 +222,24 @@ def download_and_preprocess_multimedia_content(
) -> ContextManager[dict]:
thread_count = _determine_thread_count(marqo_index, add_docs_params)
- media_repo = process_batch(docs=docs,
- thread_count=thread_count,
- tensor_fields=list(media_field_types_mapping.keys()),
- media_field_types_mapping=media_field_types_mapping,
- image_download_headers=add_docs_params.image_download_headers,
- download_headers=None, # TODO verify if this is used
- marqo_index_type=marqo_index.type,
- device=add_docs_params.device,
- marqo_index_model=marqo_index.model,
- model_name=marqo_index.model.name,
- model_properties=marqo_index.model.properties,
- normalize_embeddings=marqo_index.normalize_embeddings,
- model_auth=add_docs_params.model_auth,
- patch_method_exists=marqo_index.image_preprocessing.patch_method is not None,
- audio_preprocessing=marqo_index.audio_preprocessing,
- video_preprocessing=marqo_index.video_preprocessing,
- force_download=False, # TODO verify if this is used
- )
+ media_repo = process_batch(
+ docs=docs,
+ thread_count=thread_count,
+ tensor_fields=list(media_field_types_mapping.keys()),
+ media_field_types_mapping=media_field_types_mapping,
+ media_download_headers=add_docs_params.media_download_headers,
+ marqo_index_type=marqo_index.type,
+ device=add_docs_params.device,
+ marqo_index_model=marqo_index.model,
+ model_name=marqo_index.model.name,
+ model_properties=marqo_index.model.properties,
+ normalize_embeddings=marqo_index.normalize_embeddings,
+ model_auth=add_docs_params.model_auth,
+ patch_method_exists=marqo_index.image_preprocessing.patch_method is not None,
+ audio_preprocessing=marqo_index.audio_preprocessing,
+ video_preprocessing=marqo_index.video_preprocessing,
+ force_download=False, # TODO verify if this is used
+ )
try:
yield media_repo
@@ -293,7 +293,7 @@ def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_
model_name: str,
normalize_embeddings: bool,
media_field_types_mapping: Optional[Dict[str, FieldType]],
- download_headers: Optional[Dict] = None, # Optional for now
+ media_download_headers: Optional[Dict] = None, # Optional for now
model_properties: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None,
device: Optional[str] = None,
@@ -305,11 +305,25 @@ def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_
force_download: bool = False
) -> ContextManager[dict]:
media_repo = {} # for image/video/audio
- media_repo = process_batch(docs, thread_count, tensor_fields, image_download_headers,
- model_name, normalize_embeddings, force_download,
- media_field_types_mapping, download_headers, model_properties, model_auth,
- device, patch_method_exists, marqo_index_type, marqo_index_model,
- audio_preprocessing, video_preprocessing)
+ media_repo = process_batch(
+        docs=docs,
+        thread_count=thread_count,
+        tensor_fields=tensor_fields,
+        image_download_headers=image_download_headers,
+        model_name=model_name,
+        normalize_embeddings=normalize_embeddings,
+        force_download=force_download,
+        media_field_types_mapping=media_field_types_mapping,
+        media_download_headers=media_download_headers,
+        model_properties=model_properties,
+        model_auth=model_auth,
+        device=device,
+        patch_method_exists=patch_method_exists,
+        marqo_index_type=marqo_index_type,
+        marqo_index_model=marqo_index_model,
+        audio_preprocessing=audio_preprocessing,
+        video_preprocessing=video_preprocessing
+ )
try:
yield media_repo
@@ -325,11 +339,13 @@ def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_
def process_batch(docs: List[dict], thread_count: int, tensor_fields: List[str],
image_download_headers: dict, model_name: str, normalize_embeddings: bool,
force_download: bool, media_field_types_mapping: Optional[Dict[str, FieldType]],
- download_headers: Optional[Dict], model_properties: Optional[Dict],
+ model_properties: Optional[Dict],
model_auth: Optional[ModelAuth], device: Optional[str],
patch_method_exists: bool, marqo_index_type: Optional[IndexType], marqo_index_model: Optional[Model],
+ media_download_headers: Optional[Dict] = None,
audio_preprocessing: Optional[AudioPreProcessing] = None,
video_preprocessing: Optional[VideoPreProcessing] = None) -> dict:
+
docs_per_thread = math.ceil(len(docs) / thread_count)
copied = copy.deepcopy(docs)
@@ -349,25 +365,27 @@ def process_batch(docs: List[dict], thread_count: int, tensor_fields: List[str],
# Consider replacing below with:
# thread_allocated_docs = [copied[i: i + docs_per_thread] for i in range(0, len(copied), docs_per_thread)]
thread_allocated_docs = [copied[i: i + docs_per_thread] for i in range(len(copied))[::docs_per_thread]]
- download_headers = download_headers if download_headers else {}
with ThreadPoolExecutor(max_workers=len(thread_allocated_docs)) as executor:
- futures = [executor.submit(threaded_download_and_preprocess_content,
- allocation,
- media_repo,
- tensor_fields,
- image_download_headers,
- device,
- media_field_types_mapping,
- download_headers,
- m[i],
- preprocessors,
- marqo_index_type,
- marqo_index_model,
- audio_preprocessing,
- video_preprocessing,
- force_download)
- for i, allocation in enumerate(thread_allocated_docs)]
+ futures = [
+ executor.submit(
+ threaded_download_and_preprocess_content,
+ allocation,
+ media_repo,
+ tensor_fields,
+ image_download_headers,
+ device,
+ media_field_types_mapping,
+ media_download_headers,
+ m[i],
+ preprocessors,
+ marqo_index_type,
+ marqo_index_model,
+ audio_preprocessing,
+ video_preprocessing,
+ force_download)
+ for i, allocation in enumerate(thread_allocated_docs)
+ ]
# Unhandled exceptions will be raised here.
# We only raise the first exception if there are multiple exceptions
diff --git a/src/marqo/tensor_search/web/api_utils.py b/src/marqo/tensor_search/web/api_utils.py
index 0c2ab4d4e..1db7cea68 100644
--- a/src/marqo/tensor_search/web/api_utils.py
+++ b/src/marqo/tensor_search/web/api_utils.py
@@ -50,27 +50,27 @@ def translate_api_device(device: Optional[str]) -> Optional[str]:
f"Acceptable device types: {acceptable_devices}")
-def decode_image_download_headers(image_download_headers: Optional[str] = None) -> dict:
+def decode_media_download_headers(media_download_headers: Optional[str] = None) -> dict:
"""Decodes an image download header string into a Python dict
Args:
- image_download_headers: JSON-serialised, URL encoded header dictionary
+ media_download_headers: JSON-serialised, URL encoded header dictionary
Returns:
- image_download_headers as a dict
+ media_download_headers as a dict
Raises:
InvalidArgError if there is trouble parsing the dictionary
"""
- if not image_download_headers:
+ if not media_download_headers:
return dict()
else:
try:
- as_str = urllib.parse.unquote_plus(image_download_headers)
+ as_str = urllib.parse.unquote_plus(media_download_headers)
as_dict = json.loads(as_str)
return as_dict
except json.JSONDecodeError as e:
- raise InvalidArgError(f"Error parsing image_download_headers. Message: {e}")
+ raise InvalidArgError(f"Error parsing media_download_headers. Message: {e}")
def decode_query_string_model_auth(model_auth: Optional[str] = None) -> Optional[ModelAuth]:
@@ -130,14 +130,14 @@ def add_docs_params_orchestrator(index_name: str, body: Union[AddDocsBodyParams,
tensor_fields = body.tensorFields
use_existing_tensors = body.useExistingTensors
model_auth = body.modelAuth
- image_download_headers = body.imageDownloadHeaders
+ media_download_headers = body.mediaDownloadHeaders
image_download_thread_count = body.imageDownloadThreadCount
text_chunk_prefix = body.textChunkPrefix
return AddDocsParams(
index_name=index_name, docs=docs,
device=device, tensor_fields=tensor_fields,
- use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers,
+ use_existing_tensors=use_existing_tensors, media_download_headers=media_download_headers,
image_download_thread_count=image_download_thread_count,
mappings=mappings, model_auth=model_auth, text_chunk_prefix=text_chunk_prefix,
batch_vectorisation_mode=body.batchVectorisationMode,
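The renamed decoder keeps the same wire format: a JSON dictionary, URL-encoded onto the query string. A round-trip sketch:

    import json
    import urllib.parse

    from marqo.tensor_search.web.api_utils import decode_media_download_headers

    headers = {"Authorization": "Bearer <token>"}
    encoded = urllib.parse.quote_plus(json.dumps(headers))
    assert decode_media_download_headers(encoded) == headers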
From 8e1bf32a99a1bef8572b638f3b4cea84c44b0738 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 01:28:44 +1100
Subject: [PATCH 02/29] Finish search
---
src/marqo/s2_inference/s2_inference.py | 84 +++++++---
src/marqo/tensor_search/api.py | 2 +-
src/marqo/tensor_search/models/api_models.py | 25 ++-
src/marqo/tensor_search/models/search.py | 43 ++++-
src/marqo/tensor_search/tensor_search.py | 157 ++++++++++---------
5 files changed, 206 insertions(+), 105 deletions(-)
diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py
index fc97d5300..874565efa 100644
--- a/src/marqo/s2_inference/s2_inference.py
+++ b/src/marqo/s2_inference/s2_inference.py
@@ -47,8 +47,28 @@
def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[bytes]],
model_properties: dict = None,
- device: str = None, normalize_embeddings: bool = get_default_normalization(),
- model_auth: ModelAuth = None, enable_cache: bool = False, modality: Modality = Modality.TEXT, **kwargs,) -> List[List[float]]:
+ device: str = None,
+ normalize_embeddings: bool = get_default_normalization(),
+ model_auth: ModelAuth = None,
+ enable_cache: bool = False,
+ modality: Modality = Modality.TEXT,
+ media_download_headers: Optional[Dict] = None,
+ infer: bool = False
+ ) -> List[List[float]]:
+ """Vectorise the given content using the given model.
+
+ Args:
+ model_name: The name of the model to use.
+ content: The content to vectorise.
+ model_properties: The properties of the model to use.
+ device: The device to use.
+ normalize_embeddings: Whether to normalize the embeddings.
+ model_auth: The model authorisation details.
+ enable_cache: Whether to enable the inference cache.
+ modality: The modality of the content.
+ media_download_headers: The media download headers.
+        infer: Whether to infer the modality. Deprecated; use modality instead.
+ """
if not device:
raise InternalError(message=f"vectorise (internal function) cannot be called without setting device!")
@@ -63,25 +83,37 @@ def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[
model = _available_models[model_cache_key][AvailableModelsKey.model]
if _marqo_inference_cache.is_enabled() and enable_cache:
- return _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _vectorise_with_cache(
+ model, model_cache_key, content, normalize_embeddings, modality,
+ media_download_headers, infer
+ )
else:
- return _vectorise_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _vectorise_without_cache(
+ model_cache_key, content, normalize_embeddings, modality,
+ media_download_headers, infer
+ )
-def _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs):
+def _vectorise_with_cache(model, model_cache_key: str, content, normalize_embeddings: bool, modality: Modality,
+ media_download_headers: Optional[Dict], infer: bool):
if isinstance(content, str):
vectorised = _marqo_inference_cache.get(model_cache_key, content)
if vectorised is None:
- vectorised = _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ vectorised = _encode_without_cache(
+ model_cache_key, content, normalize_embeddings, modality, media_download_headers, infer
+ )
_marqo_inference_cache.set(model_cache_key, content, vectorised[0])
else:
vectorised = _convert_cached_embeddings_to_output(vectorised)
return vectorised
elif isinstance(content, list):
- return _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _vectorise_list_with_cache(
+ model, model_cache_key, content, normalize_embeddings, modality, media_download_headers, infer
+ )
else:
raise TypeError(f"Unsupported content type: {type(content).__name__}")
-def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs):
+def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality,
+ media_download_headers, infer):
contents_to_vectorise = []
cached_output = []
@@ -97,7 +129,8 @@ def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embedd
contents_to_vectorise.append(content_item)
if contents_to_vectorise:
- vectorised_outputs = _encode_without_cache(model_cache_key, contents_to_vectorise, normalize_embeddings, modality, **kwargs)
+ vectorised_outputs = _encode_without_cache(
+ model_cache_key, contents_to_vectorise, normalize_embeddings, modality, media_download_headers, infer)
# Cache the vectorised outputs
for content_item, vectorised_output in zip(contents_to_vectorise, vectorised_outputs):
if isinstance(content_item, str):
@@ -110,20 +143,32 @@ def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embedd
return vectorised_outputs
-def _vectorise_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
- normalize_embeddings: bool, modality: Modality, **kwargs) -> List[List[float]]:
- return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
-def _encode_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
- normalize_embeddings: bool, modality: Modality, **kwargs) -> List[List[float]]:
+def _vectorise_without_cache(
+ model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
+ normalize_embeddings: bool, modality: Modality,
+ media_download_headers: Optional[Dict], infer: bool
+) -> List[List[float]]:
+ return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, media_download_headers, infer)
+
+def _encode_without_cache(
+ model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
+ normalize_embeddings: bool, modality: Modality, media_download_headers: Optional[Dict], infer: bool) \
+ -> List[List[float]]:
try:
model = _available_models[model_cache_key][AvailableModelsKey.model]
encoder = get_encoder(model)
if isinstance(content, str):
- vectorised = model.encode(content, normalize=normalize_embeddings, modality=modality, **kwargs)
+ vectorised = model.encode(
+ content, normalize=normalize_embeddings, modality=modality,
+ media_download_headers=media_download_headers, infer=infer
+ )
elif isinstance(content, (torch.Tensor, torch.FloatTensor)):
- vectorised = model.encode(content, normalize=normalize_embeddings, modality=modality, **kwargs)
+ vectorised = model.encode(
+ content, normalize=normalize_embeddings, modality=modality,
+ media_download_headers=media_download_headers, infer=infer
+ )
else:
vector_batches = []
batch_size = _get_max_vectorise_batch_size()
@@ -133,9 +178,10 @@ def _encode_without_cache(model_cache_key: str, content: Union[str, List[str], L
modality = infer_modality(batch[0] if isinstance(batch[0], (str, bytes)) else batch)
# TODO maybe the infer parameter can be replaced by modality
- infer = kwargs.pop('infer', False if modality == Modality.TEXT else True)
- encoded_batch = encoder.encode(batch, modality=modality, normalize=normalize_embeddings,
- infer=infer, **kwargs)
+ encoded_batch = encoder.encode(
+ batch, modality=modality, normalize=normalize_embeddings,
+ infer=infer, media_download_headers=media_download_headers
+ )
vector_batches.append(_convert_tensor_to_numpy(encoded_batch))
diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py
index 556ae4ed6..18d61776e 100644
--- a/src/marqo/tensor_search/api.py
+++ b/src/marqo/tensor_search/api.py
@@ -277,7 +277,7 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api
reranker=search_query.reRanker,
filter=search_query.filter, device=device,
attributes_to_retrieve=search_query.attributesToRetrieve, boost=search_query.boost,
- image_download_headers=search_query.image_download_headers,
+        media_download_headers=search_query.mediaDownloadHeaders,
context=search_query.context,
score_modifiers=search_query.scoreModifiers,
model_auth=search_query.modelAuth,
diff --git a/src/marqo/tensor_search/models/api_models.py b/src/marqo/tensor_search/models/api_models.py
index 0c853e557..d688e55f0 100644
--- a/src/marqo/tensor_search/models/api_models.py
+++ b/src/marqo/tensor_search/models/api_models.py
@@ -7,7 +7,7 @@
from typing import Union, List, Dict, Optional
import pydantic
-from pydantic import BaseModel, root_validator, validator
+from pydantic import BaseModel, root_validator, validator, Field
from marqo.base_model import ImmutableStrictBaseModel
from marqo.core.models.hybrid_parameters import HybridParameters
@@ -47,7 +47,8 @@ class SearchQuery(BaseMarqoModel):
filter: str = None
attributesToRetrieve: Union[None, List[str]] = None
boost: Optional[Dict] = None
- image_download_headers: Optional[Dict] = None
+    imageDownloadHeaders: Optional[Dict] = Field(default=None, alias="image_download_headers")
+ mediaDownloadHeaders: Optional[Dict] = None
context: Optional[SearchContext] = None
scoreModifiers: Optional[ScoreModifierLists] = None
modelAuth: Optional[ModelAuth] = None
@@ -68,6 +69,26 @@ def _preprocess_search_method(cls, value):
else:
return value
+ @root_validator(skip_on_failure=True)
+ def _validate_image_download_headers_and_media_download_headers(cls, values):
+ """Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.
+
+ If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
+ rest of the code.
+
+ imageDownloadHeaders is deprecated and will be removed in the future.
+ """
+ image_download_headers = values.get('imageDownloadHeaders')
+ media_download_headers = values.get('mediaDownloadHeaders')
+ if image_download_headers and media_download_headers:
+            raise ValueError("Cannot set both imageDownloadHeaders (image_download_headers) and mediaDownloadHeaders. "
+                             "imageDownloadHeaders (image_download_headers) is deprecated and will be removed in a future release. "
+                             "Use mediaDownloadHeaders instead.")
+ if image_download_headers:
+ values['mediaDownloadHeaders'] = image_download_headers
+ return values
+
+
@root_validator(pre=False, skip_on_failure=True)
def validate_query_and_context(cls, values):
"""Validate that one of query and context are present for tensor/hybrid search, or just the query for lexical search.
diff --git a/src/marqo/tensor_search/models/search.py b/src/marqo/tensor_search/models/search.py
index 4d8c5a74c..0a7101eb8 100644
--- a/src/marqo/tensor_search/models/search.py
+++ b/src/marqo/tensor_search/models/search.py
@@ -4,7 +4,9 @@
from typing import Any, Union, List, Dict, Optional, NewType, Literal
from marqo.api.exceptions import InvalidArgError
+from marqo.core.models import MarqoQuery
from marqo.tensor_search.models.private_models import ModelAuth
+from marqo.s2_inference.multimodal_model_load import Modality
Qidx = NewType('Qidx', int) # Indicates the position of a search query in a bulk search request
JHash = NewType('JHash', int) # hash of a VectoriseJob. Used for quick access of VectorisedJobs
@@ -26,25 +28,25 @@ class VectorisedJobs(BaseModel):
content: List[Union[str, List[str]]]
device: str
normalize_embeddings: bool
- image_download_headers: Optional[Dict]
- content_type: Literal['text', 'media']
+ media_download_headers: Optional[Dict]
model_auth: Optional[ModelAuth]
+ modality: Modality
def __hash__(self):
return self.groupby_key() + hash(json.dumps(self.content, sort_keys=True))
def groupby_key(self) -> JHash:
return VectorisedJobs.get_groupby_key(self.model_name, self.model_properties, self.device,
- self.normalize_embeddings, self.content_type,
- self.image_download_headers)
+ self.normalize_embeddings, self.modality,
+ self.media_download_headers)
@staticmethod
def get_groupby_key(model_name: str, model_properties: Dict[str, Any], device: str,
- normalize_embeddings: bool, content_type: str, image_download_headers: Optional[Dict]) -> JHash:
+ normalize_embeddings: bool, modality: str, media_download_headers: Optional[Dict]) -> JHash:
return JHash(hash(model_name) + hash(json.dumps(model_properties, sort_keys=True))
+ hash(device) + hash(normalize_embeddings)
- + hash(content_type)
- + hash(json.dumps(image_download_headers, sort_keys=True))
+ + hash(modality)
+ + hash(json.dumps(media_download_headers, sort_keys=True))
)
def add_content(self, content: List[Union[str, List[str]]]) -> VectorisedJobPointer:
@@ -75,4 +77,29 @@ def __init__(self, **data):
def check_vector_length(cls, v):
if not (1 <= len(v) <= 64):
raise InvalidArgError('The number of tensors must be between 1 and 64')
- return v
\ No newline at end of file
+ return v
+
+
+class QueryContent(BaseModel):
+ content: str
+ modality: Modality
+
+
+class QueryContentCollector(BaseModel):
+ queries: List[QueryContent]
+ @property
+ def text_queries(self) -> List[QueryContent]:
+ return [q for q in self.queries if q.modality == Modality.TEXT]
+
+ @property
+ def image_queries(self) -> List[QueryContent]:
+ return [q for q in self.queries if q.modality == Modality.IMAGE]
+
+ @property
+ def video_queries(self) -> List[QueryContent]:
+ return [q for q in self.queries if q.modality == Modality.VIDEO]
+
+ @property
+ def audio_queries(self) -> List[QueryContent]:
+ return [q for q in self.queries if q.modality == Modality.AUDIO]
+
\ No newline at end of file
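The new collector groups query content by modality so each group can be vectorised as a single job; a small sketch of its accessors (content values illustrative):

    from marqo.s2_inference.multimodal_model_load import Modality
    from marqo.tensor_search.models.search import QueryContent, QueryContentCollector

    collector = QueryContentCollector(queries=[
        QueryContent(content="a photo of a cat", modality=Modality.TEXT),
        QueryContent(content="https://example.com/cat.png", modality=Modality.IMAGE),
    ])
    assert len(collector.text_queries) == 1
    assert len(collector.image_queries) == 1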
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index a15646254..3addc0e62 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -90,7 +90,7 @@
from marqo.tensor_search.models.delete_docs_objects import MqDeleteDocsRequest
from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.models.search import Qidx, JHash, SearchContext, VectorisedJobs, VectorisedJobPointer, \
- SearchContextTensor
+ SearchContextTensor, QueryContentCollector, QueryContent
from marqo.tensor_search.telemetry import RequestMetricsStore
from marqo.tensor_search.tensor_search_logging import get_logger
from marqo.vespa.exceptions import VespaStatusError
@@ -1465,7 +1465,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
reranker: Union[str, Dict] = None, filter: Optional[str] = None,
attributes_to_retrieve: Optional[List[str]] = None,
device: str = None, boost: Optional[Dict] = None,
- image_download_headers: Optional[Dict] = None,
+ media_download_headers: Optional[Dict] = None,
context: Optional[SearchContext] = None,
score_modifiers: Optional[ScoreModifierLists] = None,
model_auth: Optional[ModelAuth] = None,
@@ -1493,7 +1493,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
device: May be none, we calculate default device here
num_highlights: number of highlights to return for each doc
boost: boosters to re-weight the scores of individual fields
- image_download_headers: headers for downloading images
+ media_download_headers: headers to use when downloading media
context: a dictionary to allow custom vectors in search, for tensor search only
score_modifiers: a dictionary to modify the score based on field values, for tensor search only
model_auth: Authorisation details for downloading a model (if required)
@@ -1583,7 +1583,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
ef_search=ef_search, approximate=approximate, searchable_attributes=searchable_attributes,
filter_string=filter, device=selected_device, attributes_to_retrieve=attributes_to_retrieve,
boost=boost,
- image_download_headers=image_download_headers, context=context, score_modifiers=score_modifiers,
+ media_download_headers=media_download_headers, context=context, score_modifiers=score_modifiers,
model_auth=model_auth, highlights=highlights, text_query_prefix=text_query_prefix
)
elif search_method.upper() == SearchMethod.HYBRID:
@@ -1594,7 +1594,7 @@ def search(config: Config, index_name: str, text: Optional[Union[str, dict, Cust
ef_search=ef_search, approximate=approximate, searchable_attributes=searchable_attributes,
filter_string=filter, device=selected_device, attributes_to_retrieve=attributes_to_retrieve,
boost=boost,
- image_download_headers=image_download_headers, context=context, score_modifiers=score_modifiers,
+ media_download_headers=media_download_headers, context=context, score_modifiers=score_modifiers,
model_auth=model_auth, highlights=highlights, text_query_prefix=text_query_prefix,
hybrid_parameters=hybrid_parameters
)
@@ -1735,37 +1735,39 @@ def _lexical_search(
return gathered_docs
-def construct_vector_input_batches(query: Optional[Union[str, Dict]], index_info: MarqoIndex) \
- -> Tuple[List[str], List[str]]:
+def construct_vector_input_batches(query: Optional[Union[str, Dict]], media_download_headers: Optional[Dict] = None) \
+ -> QueryContentCollector:
"""Splits images from text in a single query (either a query string, or dict of weighted strings).
Args:
query: a string query, or a dict of weighted strings.
- index_info: used to determine whether URLs should be treated as images
+ media_download_headers: headers to use when downloading media
Returns:
- A tuple of string batches. The first is text content the second is image content.
+        A QueryContentCollector object with the text and media content separated.
"""
# TODO - infer this from model
- treat_urls_as_media = True
-
+ query_content_list = []
if isinstance(query, str):
- if treat_urls_as_media and validate_url(query):
- return [], [query, ]
- else:
- return [query, ], []
+ query_content_list.append(
+ QueryContent(
+ content=query,
+ modality=infer_modality(query, media_download_headers=media_download_headers)
+ )
+ )
elif isinstance(query, dict): # is dict:
- ordered_queries = list(query.items())
- if treat_urls_as_media:
- text_queries = [k for k, _ in ordered_queries if not _is_image(k)]
- image_queries = [k for k, _ in ordered_queries if _is_image(k)]
- return text_queries, image_queries
- else:
- return [k for k, _ in ordered_queries], []
+        for content, _weight in query.items():
+            query_content_list.append(
+                QueryContent(
+                    content=content,
+                    modality=infer_modality(content, media_download_headers=media_download_headers)
+                )
+            )
elif query is None:
- return [], []
+ pass
else:
raise ValueError(f"Incorrect type for query: {type(query).__name__}")
+    return QueryContentCollector(queries=query_content_list)
def gather_documents_from_response(response: QueryResult, marqo_index: MarqoIndex, highlights: bool,
@@ -1800,7 +1802,7 @@ def unstructured_index_attributes_to_retrieve(marqo_doc: Dict[str, Any], attribu
def assign_query_to_vector_job(
q: BulkSearchQueryEntity, jobs: Dict[JHash, VectorisedJobs],
- grouped_content: Tuple[List[str], List[str], List[str], List[str]],
+ grouped_content: QueryContentCollector,
index_info: MarqoIndex, device: str) -> List[VectorisedJobPointer]:
"""
For a individual query, assign its content (to be vectorised) to a vector job. If none exist with the correct
@@ -1819,34 +1821,39 @@ def assign_query_to_vector_job(
Returns:
A list of pointers to the location in a vector job that will have its vectorised content.
"""
- if len(grouped_content) != 2:
- raise RuntimeError(
- "assign_query_to_vector_job() expects param `grouped_content` with 2 elems. Instead received"
- f" `grouped_content` with {len(grouped_content)} elems")
ptrs = []
- for i, content in enumerate(grouped_content):
- content_type = ['text', 'media'][i]
- vector_job = VectorisedJobs(
- model_name=index_info.model.name,
- model_properties=index_info.model.get_properties(),
- content=content,
- device=device,
- normalize_embeddings=index_info.normalize_embeddings,
- image_download_headers=q.image_download_headers,
- content_type=content_type,
- model_auth=q.modelAuth
- )
- # If exists, add content to vector job. Otherwise create new
- if jobs.get(vector_job.groupby_key()) is not None:
- j = jobs.get(vector_job.groupby_key())
- ptrs.append(j.add_content(content))
- else:
- jobs[vector_job.groupby_key()] = vector_job
- ptrs.append(VectorisedJobPointer(
- job_hash=vector_job.groupby_key(),
- start_idx=0,
- end_idx=len(vector_job.content)
- ))
+ content_lists_by_modality = [
+ grouped_content.text_queries,
+ grouped_content.image_queries,
+ grouped_content.audio_queries,
+ grouped_content.video_queries,
+ ]
+
+    for list_of_queries_by_modalities in content_lists_by_modality:
+ if len(list_of_queries_by_modalities) > 0:
+ content: List[str] = [query.content for query in list_of_queries_by_modalities]
+ modality: Modality = list_of_queries_by_modalities[0].modality
+ vector_job = VectorisedJobs(
+ model_name=index_info.model.name,
+ model_properties=index_info.model.get_properties(),
+ content=content,
+ device=device,
+ normalize_embeddings=index_info.normalize_embeddings,
+ media_download_headers=q.mediaDownloadHeaders,
+ model_auth=q.modelAuth,
+                modality=modality
+ )
+ # If exists, add content to vector job. Otherwise create new
+ if jobs.get(vector_job.groupby_key()) is not None:
+ j = jobs.get(vector_job.groupby_key())
+ ptrs.append(j.add_content(content))
+ else:
+ jobs[vector_job.groupby_key()] = vector_job
+ ptrs.append(VectorisedJobPointer(
+ job_hash=vector_job.groupby_key(),
+ start_idx=0,
+ end_idx=len(vector_job.content)
+ ))
return ptrs
@@ -1865,9 +1872,8 @@ def create_vector_jobs(queries: List[BulkSearchQueryEntity], config: Config, dev
qidx_to_job: Dict[Qidx, List[VectorisedJobPointer]] = dict()
jobs: Dict[JHash, VectorisedJobs] = {}
for i, q in enumerate(queries):
- q = queries[i]
# split images, from text:
- to_be_vectorised: Tuple[List[str], List[str]] = construct_vector_input_batches(q.q, q.index)
+ to_be_vectorised: QueryContentCollector = construct_vector_input_batches(q.q, q.mediaDownloadHeaders)
qidx_to_job[i] = assign_query_to_vector_job(q, jobs, to_be_vectorised, q.index, device)
return qidx_to_job, jobs
@@ -1882,12 +1888,13 @@ def vectorise_jobs(jobs: List[VectorisedJobs]) -> Dict[JHash, Dict[str, List[flo
# TODO: Handle exception for single job, and allow others to run.
try:
if v.content:
- modality = infer_modality(v.content[0] if isinstance(v.content, list) else v.content)
+ modality = infer_modality(v.content[0] if isinstance(v.content, list) else v.content,
+ media_download_headers=v.media_download_headers)
vectors = s2_inference.vectorise(
model_name=v.model_name, model_properties=v.model_properties,
content=v.content, device=v.device,
normalize_embeddings=v.normalize_embeddings,
- image_download_headers=v.image_download_headers,
+ media_download_headers=v.media_download_headers,
model_auth=v.model_auth,
enable_cache=True,
modality=modality
@@ -1940,11 +1947,13 @@ def get_query_vectors_from_jobs(
if ordered_queries:
# multiple queries. We have to weight and combine them:
vectorised_ordered_queries = [
- (get_content_vector(
+ (
+ get_content_vector(
possible_jobs=qidx_to_job[qidx],
jobs=jobs,
job_to_vectors=job_to_vectors,
- content=content),
+ content=content
+ ),
weight,
content
) for content, weight in ordered_queries
@@ -1999,15 +2008,12 @@ def get_content_vector(possible_jobs: List[VectorisedJobPointer], job_to_vectors
Raises runtime error if is not found
"""
- content_type = 'text' if infer_modality(content) == Modality.TEXT else 'media'
-
not_found_error = RuntimeError(f"get_content_vector(): could not find corresponding vector for content `{content}`")
for vec_job_pointer in possible_jobs:
- if jobs[vec_job_pointer.job_hash].content_type == content_type:
- try:
- return job_to_vectors[vec_job_pointer.job_hash][content]
- except KeyError:
- raise not_found_error
+ try:
+ return job_to_vectors[vec_job_pointer.job_hash][content]
+ except KeyError:
+ raise not_found_error
raise not_found_error
@@ -2019,19 +2025,20 @@ def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSear
if q.q is None:
prefixed_q = q.q
elif isinstance(q.q, str):
- if _is_image(q.q):
- prefixed_q = q.q
- else:
+ modality = infer_modality(q.q, q.mediaDownloadHeaders)
+ if modality == Modality.TEXT:
prefixed_q = f"{text_query_prefix}{q.q}"
+ else:
+ prefixed_q = q.q
else: # q.q is dict
prefixed_q = {}
for key, value in q.q.items():
# Apply prefix if key is not an image or if index does not treat URLs and pointers as images
- if _is_image(key):
- prefixed_q[key] = value
+ modality = infer_modality(key, q.mediaDownloadHeaders)
+ if modality == Modality.TEXT:
+                    prefixed_q[f"{text_query_prefix}{key}"] = value
else:
- prefixed_q[f"{text_query_prefix}{key}"] = value
-
+ prefixed_q[key] = value
new_query_object = BulkSearchQueryEntity(
q=prefixed_q,
searchableAttributes=q.searchableAttributes,
@@ -2042,7 +2049,7 @@ def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSear
filter=q.filter,
attributesToRetrieve=q.attributesToRetrieve,
boost=q.boost,
- image_download_headers=q.image_download_headers,
+ mediaDownloadHeaders=q.mediaDownloadHeaders,
context=q.context,
scoreModifiers=q.scoreModifiers,
index=q.index,
@@ -2087,7 +2094,7 @@ def _vector_text_search(
ef_search: Optional[int] = None, approximate: bool = True,
searchable_attributes: Iterable[str] = None, filter_string: str = None, device: str = None,
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None,
- image_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None,
+ media_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None,
score_modifiers: Optional[ScoreModifierLists] = None, model_auth: Optional[ModelAuth] = None,
highlights: bool = False, text_query_prefix: Optional[str] = None) -> Dict:
"""
@@ -2104,7 +2111,7 @@ def _vector_text_search(
verbose: if 0 - nothing is printed. if 1 - data is printed without vectors, if 2 - full
objects are printed out
attributes_to_retrieve: if set, only returns these fields
- image_download_headers: headers for downloading images
+ media_download_headers: headers for downloading media
context: a dictionary to allow custom vectors in search
score_modifiers: a dictionary to modify the score based on field values, for tensor search only
model_auth: Authorisation details for downloading a model (if required)
@@ -2153,7 +2160,7 @@ def _vector_text_search(
queries = [BulkSearchQueryEntity(
q=query, searchableAttributes=searchable_attributes, searchMethod=SearchMethod.TENSOR, limit=result_count,
offset=offset, showHighlights=False, filter=filter_string, attributesToRetrieve=attributes_to_retrieve,
- boost=boost, image_download_headers=image_download_headers, context=context, scoreModifiers=score_modifiers,
+ boost=boost, mediaDownloadHeaders=media_download_headers, context=context, scoreModifiers=score_modifiers,
index=marqo_index, modelAuth=model_auth, text_query_prefix=text_query_prefix
)]
From 574e61d3642a24a6f5e00cee32af7bf234d75aeb Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 10:34:33 +1100
Subject: [PATCH 03/29] Finish development
---
src/marqo/api/models/embed_request.py | 28 +++++++++++++++++--
src/marqo/core/embed/embed.py | 13 +++++----
.../embedding_models/abstract_clip_model.py | 2 +-
src/marqo/s2_inference/clip_utils.py | 2 +-
.../s2_inference/multimodal_model_load.py | 20 ++++++-------
src/marqo/s2_inference/onnx_clip_utils.py | 2 +-
src/marqo/tensor_search/api.py | 2 +-
7 files changed, 46 insertions(+), 23 deletions(-)
diff --git a/src/marqo/api/models/embed_request.py b/src/marqo/api/models/embed_request.py
index 9ca47422e..ff16f6a3a 100644
--- a/src/marqo/api/models/embed_request.py
+++ b/src/marqo/api/models/embed_request.py
@@ -6,6 +6,8 @@
import pydantic
from typing import Union, List, Dict, Optional, Any
+from pydantic import Field, root_validator
+
from marqo.tensor_search.models.private_models import ModelAuth
from marqo.tensor_search.models.api_models import BaseMarqoModel
from marqo.core.embed.embed import EmbedContentType
@@ -15,9 +17,10 @@
class EmbedRequest(BaseMarqoModel):
# content can be a single query or list of queries. Queries can be a string or a dictionary.
content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]]
- image_download_headers: Optional[Dict] = None
+ image_download_headers: Optional[Dict] = Field(default=None, alias="imageDownloadHeaders")
+ mediaDownloadHeaders: Optional[Dict] = Field(default=None, alias="mediaDownloadHeaders")
modelAuth: Optional[ModelAuth] = None
- content_type: Optional[EmbedContentType] = EmbedContentType.Query
+    content_type: Optional[EmbedContentType] = Field(EmbedContentType.Query, alias="contentType")
@pydantic.validator('content')
def validate_content(cls, value):
@@ -47,4 +50,23 @@ def validate_content(cls, value):
else:
raise ValueError("Embed content should be a string, a dictionary, or a list of strings or dictionaries")
- return value
\ No newline at end of file
+ return value
+
+ @root_validator(skip_on_failure=True)
+ def _validate_image_download_headers_and_media_download_headers(cls, values):
+ """Validate imageDownloadHeaders and mediaDownloadHeaders. Raise an error if both are set.
+
+ If imageDownloadHeaders is set, set mediaDownloadHeaders to it and use mediaDownloadHeaders in the
+ rest of the code.
+
+ imageDownloadHeaders is deprecated and will be removed in the future.
+ """
+        image_download_headers = values.get('image_download_headers')
+ media_download_headers = values.get('mediaDownloadHeaders')
+ if image_download_headers and media_download_headers:
+            raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
+                             "imageDownloadHeaders is deprecated and will be removed in a future release. "
+                             "Use mediaDownloadHeaders instead.")
+ if image_download_headers:
+ values['mediaDownloadHeaders'] = image_download_headers
+ return values
\ No newline at end of file
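EmbedRequest follows the same deprecation path as search and add_documents; a minimal sketch, assuming the request validates as shown above and with an illustrative header value:

    from marqo.api.models.embed_request import EmbedRequest

    req = EmbedRequest(content="a photo of a cat",
                       imageDownloadHeaders={"Authorization": "Bearer <token>"})
    assert req.mediaDownloadHeaders == {"Authorization": "Bearer <token>"}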
diff --git a/src/marqo/core/embed/embed.py b/src/marqo/core/embed/embed.py
index 29d6fcf54..4730ddcd5 100644
--- a/src/marqo/core/embed/embed.py
+++ b/src/marqo/core/embed/embed.py
@@ -34,11 +34,12 @@ def validate_default_device(cls, value):
return value
def embed_content(
- self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
- index_name: str, device: str = None, image_download_headers: Optional[Dict] = None,
- model_auth: Optional[ModelAuth] = None,
- content_type: Optional[EmbedContentType] = EmbedContentType.Query
- ) -> Dict:
+ self, content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]],
+ index_name: str, device: str = None,
+ media_download_headers: Optional[Dict] = None,
+ model_auth: Optional[ModelAuth] = None,
+ content_type: Optional[EmbedContentType] = EmbedContentType.Query
+ ) -> Dict:
"""
Use the index's model to embed the content
@@ -105,7 +106,7 @@ def embed_content(
BulkSearchQueryEntity(
q=content_entry,
index=marqo_index,
- image_download_headers=image_download_headers,
+ mediaDownloadHeaders=media_download_headers,
modelAuth=model_auth,
text_query_prefix=prefix
# TODO: Check if it's fine that we leave out the other parameters
diff --git a/src/marqo/core/inference/embedding_models/abstract_clip_model.py b/src/marqo/core/inference/embedding_models/abstract_clip_model.py
index 1b2a33b23..b89728a5d 100644
--- a/src/marqo/core/inference/embedding_models/abstract_clip_model.py
+++ b/src/marqo/core/inference/embedding_models/abstract_clip_model.py
@@ -68,7 +68,7 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
if is_image:
logger.debug('image')
- image_download_headers = kwargs.get("image_download_headers", dict())
+ image_download_headers = kwargs.get("media_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
else:
logger.debug('text')
diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py
index 342e6d849..f4f7acde1 100644
--- a/src/marqo/s2_inference/clip_utils.py
+++ b/src/marqo/s2_inference/clip_utils.py
@@ -485,7 +485,7 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
if is_image:
logger.debug('image')
- image_download_headers = kwargs.get("image_download_headers", dict())
+ image_download_headers = kwargs.get("media_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
else:
logger.debug('text')
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index 0c7ed9431..e57f74c6a 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -110,15 +110,15 @@ def preprocessor(self, modality):
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
return self.encoder.preprocessor(modality)
- def encode(self, content, modality, **kwargs):
+    def encode(self, content, modality, media_download_headers: Optional[Dict] = None, **kwargs):
if self.encoder is None:
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
- return self.encoder.encode(content, modality, **kwargs)
+ return self.encoder.encode(content, modality, media_download_headers, **kwargs)
class ModelEncoder(ABC):
@abstractmethod
- def encode(self, content, modality, **kwargs):
+ def encode(self, content, modality, media_download_headers, **kwargs):
pass
@@ -126,8 +126,8 @@ class DefaultEncoder(ModelEncoder):
def __init__(self, model):
self.model = model
- def encode(self, content, modality, **kwargs):
- return self.model.encode(content, **kwargs)
+ def encode(self, content, modality, media_download_headers, **kwargs):
+ return self.model.encode(content, media_download_headers, **kwargs)
@contextmanager
@@ -251,7 +251,7 @@ def preprocessor(self, modality):
return self._preprocessors.get(modality)
- def encode(self, content, modality, normalize=True, **kwargs):
+    def encode(self, content, modality, normalize=True, media_download_headers: Optional[Dict] = None, **kwargs):
inputs = {}
if modality == Modality.TEXT:
@@ -269,7 +269,7 @@ def encode(self, content, modality, normalize=True, **kwargs):
with open(temp_filename, 'wb') as f:
f.write(content)
elif isinstance(content, str) and "http" in content:
- self._download_content(content, temp_filename)
+ self._download_content(content, temp_filename, media_download_headers)
else:
return self.encode([content], modality=Modality.TEXT)
@@ -280,7 +280,7 @@ def encode(self, content, modality, normalize=True, **kwargs):
if isinstance(content, str) and "http" in content:
suffix = ".mp4" if modality == Modality.VIDEO else ".wav"
with self._temp_file(suffix) as temp_filename:
- self._download_content(content, temp_filename)
+ self._download_content(content, temp_filename, media_download_headers)
preprocessed_content = self.preprocessor(modality)([temp_filename], return_tensors='pt')
inputs[modality.value] = to_device(preprocessed_content, self.model.device)['pixel_values']
@@ -302,11 +302,11 @@ def encode(self, content, modality, normalize=True, **kwargs):
return embeddings.cpu().numpy()
- def _download_content(self, url, filename):
+    def _download_content(self, url, filename, media_download_headers: Optional[Dict] = None):
# 3 seconds for images, 20 seconds for audio and video
timeout_ms = 3000 if filename.endswith(('.png', '.jpg', '.jpeg')) else 20000
- buffer = download_image_from_url(url, {}, timeout_ms)
+ buffer = download_image_from_url(url, media_download_headers, timeout_ms)
with open(filename, 'wb') as f:
f.write(buffer.getvalue())
diff --git a/src/marqo/s2_inference/onnx_clip_utils.py b/src/marqo/s2_inference/onnx_clip_utils.py
index 31da79185..a9a6ee338 100644
--- a/src/marqo/s2_inference/onnx_clip_utils.py
+++ b/src/marqo/s2_inference/onnx_clip_utils.py
@@ -167,7 +167,7 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
raise ValueError(f"expected default='image' or default='text' but received {default}")
if is_image:
- logger.debug('image')
+            logger.debug('image')
return self.encode_image(inputs, normalize=True)
else:
logger.debug('text')
diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py
index 18d61776e..8386b8382 100644
--- a/src/marqo/tensor_search/api.py
+++ b/src/marqo/tensor_search/api.py
@@ -334,7 +334,7 @@ def embed(embedding_request: EmbedRequest, index_name: str, device: str = Depend
return marqo_config.embed.embed_content(
content=embedding_request.content,
index_name=index_name, device=device,
- image_download_headers=embedding_request.image_download_headers,
+ media_download_headers=embedding_request.mediaDownloadHeaders,
model_auth=embedding_request.modelAuth,
content_type=embedding_request.content_type
)
From dfbe81c5062f7524f9ffeb832db429610216efbd Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 11:04:42 +1100
Subject: [PATCH 04/29] Fix search bugs when more than two modalities are used
---
src/marqo/tensor_search/tensor_search.py | 22 +++++++++++-----------
1 file changed, 11 insertions(+), 11 deletions(-)
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index 3addc0e62..0aaaf61ed 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -1888,8 +1888,10 @@ def vectorise_jobs(jobs: List[VectorisedJobs]) -> Dict[JHash, Dict[str, List[flo
# TODO: Handle exception for single job, and allow others to run.
try:
if v.content:
- modality = infer_modality(v.content[0] if isinstance(v.content, list) else v.content,
- media_download_headers=v.media_download_headers)
+ modality = infer_modality(
+ v.content[0] if isinstance(v.content, list) else v.content,
+ media_download_headers=v.media_download_headers
+ )
vectors = s2_inference.vectorise(
model_name=v.model_name, model_properties=v.model_properties,
content=v.content, device=v.device,
@@ -1950,7 +1952,6 @@ def get_query_vectors_from_jobs(
(
get_content_vector(
possible_jobs=qidx_to_job[qidx],
- jobs=jobs,
job_to_vectors=job_to_vectors,
content=content
),
@@ -1984,7 +1985,6 @@ def get_query_vectors_from_jobs(
# result[qidx] = vectors[0]
result[qidx] = get_content_vector(
possible_jobs=qidx_to_job.get(qidx, []),
- jobs=jobs,
job_to_vectors=job_to_vectors,
content=q.q
)
@@ -1993,14 +1993,16 @@ def get_query_vectors_from_jobs(
return result
-def get_content_vector(possible_jobs: List[VectorisedJobPointer], job_to_vectors: Dict[JHash, Dict[str, List[float]]],
- jobs: Dict[JHash, VectorisedJobs], content: str) -> List[float]:
+def get_content_vector(
+ possible_jobs: List[VectorisedJobPointer],
+ job_to_vectors: Dict[JHash, Dict[str, List[float]]],
+ content: str
+) -> List[float]:
"""finds the vector associated with a piece of content
Args:
possible_jobs: The jobs where the target vector may reside
- treat_urls_as_media: an index_parameter that indicates whether content should be treated as image, audio, video
- if it has a URL structure
+        job_to_vectors: a mapping from job hash to the vectors computed for that job's content
content: The content to search
Returns:
@@ -2010,10 +2012,8 @@ def get_content_vector(possible_jobs: List[VectorisedJobPointer], job_to_vectors
"""
not_found_error = RuntimeError(f"get_content_vector(): could not find corresponding vector for content `{content}`")
for vec_job_pointer in possible_jobs:
- try:
+ if content in job_to_vectors[vec_job_pointer.job_hash]:
return job_to_vectors[vec_job_pointer.job_hash][content]
- except KeyError:
- raise not_found_error
raise not_found_error
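After this simplification, get_content_vector is a plain lookup across the candidate jobs; a sketch with hand-built inputs (the hash and vector values are made up for illustration):

    from marqo.tensor_search.models.search import VectorisedJobPointer
    from marqo.tensor_search.tensor_search import get_content_vector

    pointer = VectorisedJobPointer(job_hash=123, start_idx=0, end_idx=1)
    job_to_vectors = {123: {"a photo of a cat": [0.1, 0.2, 0.3]}}
    assert get_content_vector(possible_jobs=[pointer], job_to_vectors=job_to_vectors,
                              content="a photo of a cat") == [0.1, 0.2, 0.3]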
From 3cc7a2f1d596f02e92b9063f904ecde6951821b1 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 12:32:51 +1100
Subject: [PATCH 05/29] Need to fix infer issue
---
.../embedding_models/abstract_clip_model.py | 4 +--
src/marqo/s2_inference/clip_utils.py | 7 ++--
.../s2_inference/multimodal_model_load.py | 2 +-
src/marqo/s2_inference/s2_inference.py | 2 +-
src/marqo/tensor_search/add_docs.py | 32 +++++++++----------
src/marqo/tensor_search/tensor_search.py | 4 +--
6 files changed, 25 insertions(+), 26 deletions(-)
diff --git a/src/marqo/core/inference/embedding_models/abstract_clip_model.py b/src/marqo/core/inference/embedding_models/abstract_clip_model.py
index b89728a5d..43eb3a849 100644
--- a/src/marqo/core/inference/embedding_models/abstract_clip_model.py
+++ b/src/marqo/core/inference/embedding_models/abstract_clip_model.py
@@ -53,8 +53,8 @@ def encode_text(self, inputs: Union[str, List[str]], normalize: bool = True) ->
def encode_image(self, inputs, normalize: bool = True, image_download_headers: dict = None) -> np.ndarray:
pass
- def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
- default: str = 'text', normalize=True, **kwargs) -> np.ndarray:
+ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], normalize=True, **kwargs) -> np.ndarray:
+ default = "text"
infer = kwargs.pop('infer', True)
if infer and _is_image(inputs):
is_image = True
diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py
index f4f7acde1..200b795ab 100644
--- a/src/marqo/s2_inference/clip_utils.py
+++ b/src/marqo/s2_inference/clip_utils.py
@@ -177,6 +177,8 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
c.setopt(pycurl.FOLLOWLOCATION, 1)
headers = DEFAULT_HEADERS.copy()
+ if image_download_headers is None:
+ image_download_headers = dict()
headers.update(image_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
@@ -467,9 +469,8 @@ def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType,
assert outputs.shape == _shape_before
return self._convert_output(outputs)
- def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
- default: str = 'text', normalize=True, **kwargs) -> FloatTensor:
-
+ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], normalize=True, **kwargs) -> FloatTensor:
+ default = "text"
infer = kwargs.pop('infer', True)
if infer and _is_image(inputs):
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index e57f74c6a..61e4992c6 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -127,7 +127,7 @@ def __init__(self, model):
self.model = model
def encode(self, content, modality, media_download_headers, **kwargs):
- return self.model.encode(content, media_download_headers, **kwargs)
+ return self.model.encode(content, modality=modality, media_download_headers=media_download_headers, **kwargs)
@contextmanager
diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py
index 874565efa..a5d92fb34 100644
--- a/src/marqo/s2_inference/s2_inference.py
+++ b/src/marqo/s2_inference/s2_inference.py
@@ -53,7 +53,7 @@ def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[
enable_cache: bool = False,
modality: Modality = Modality.TEXT,
media_download_headers: Optional[Dict] = None,
- infer: bool = False
+ infer: bool = True
) -> List[List[float]]:
"""Vectorise the given content using the given model.
diff --git a/src/marqo/tensor_search/add_docs.py b/src/marqo/tensor_search/add_docs.py
index c87a7c78f..1643c432d 100644
--- a/src/marqo/tensor_search/add_docs.py
+++ b/src/marqo/tensor_search/add_docs.py
@@ -39,7 +39,6 @@
def threaded_download_and_preprocess_content(allocated_docs: List[dict],
media_repo: dict,
tensor_fields: List[str],
- image_download_headers: dict,
device: str = None,
media_field_types_mapping: Optional[Dict[str, FieldType]] = None,
media_download_headers: Optional[Dict] = None,
@@ -118,7 +117,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
try:
- media_repo[doc[field]] = clip_utils.load_image_from_path(doc[field], image_download_headers,
+ media_repo[doc[field]] = clip_utils.load_image_from_path(doc[field], media_download_headers,
timeout_ms=int(
TIMEOUT_SECONDS * 1000),
metrics_obj=metric_obj)
@@ -166,7 +165,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
try:
- processed_chunks = download_and_chunk_media(doc[field], device, download_headers, inferred_modality,
+ processed_chunks = download_and_chunk_media(doc[field], device, media_download_headers, inferred_modality,
marqo_index_type, marqo_index_model, preprocessors,
audio_preprocessing, video_preprocessing)
media_repo[doc[field]] = processed_chunks
@@ -188,7 +187,7 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
try:
media_repo[sub_field] = clip_utils.load_image_from_path(
sub_field,
- image_download_headers,
+ media_download_headers,
timeout=TIMEOUT_SECONDS,
metrics_obj=metric_obj
)
@@ -289,11 +288,10 @@ def _determine_thread_count(marqo_index: MarqoIndex, add_docs_params: AddDocsPar
@contextmanager
def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_fields: List[str],
- image_download_headers: dict,
model_name: str,
normalize_embeddings: bool,
media_field_types_mapping: Optional[Dict[str, FieldType]],
- media_download_headers: Optional[Dict] = None, # Optional for now
+ media_download_headers: Optional[Dict] = None,
model_properties: Optional[Dict] = None,
model_auth: Optional[ModelAuth] = None,
device: Optional[str] = None,
@@ -309,7 +307,6 @@ def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_
docs = docs,
thread_count = thread_count,
tensor_fields = tensor_fields,
- image_download_headers = image_download_headers,
model_name = model_name,
normalize_embeddings = normalize_embeddings,
force_download = force_download,
@@ -336,15 +333,17 @@ def download_and_preprocess_content(docs: List[dict], thread_count: int, tensor_
pass
-def process_batch(docs: List[dict], thread_count: int, tensor_fields: List[str],
- image_download_headers: dict, model_name: str, normalize_embeddings: bool,
- force_download: bool, media_field_types_mapping: Optional[Dict[str, FieldType]],
- model_properties: Optional[Dict],
- model_auth: Optional[ModelAuth], device: Optional[str],
- patch_method_exists: bool, marqo_index_type: Optional[IndexType], marqo_index_model: Optional[Model],
- media_download_headers: Optional[Dict] = None,
- audio_preprocessing: Optional[AudioPreProcessing] = None,
- video_preprocessing: Optional[VideoPreProcessing] = None) -> dict:
+def process_batch(
+ docs: List[dict], thread_count: int, tensor_fields: List[str],
+ model_name: str, normalize_embeddings: bool,
+ force_download: bool, media_field_types_mapping: Optional[Dict[str, FieldType]],
+ model_properties: Optional[Dict],
+ model_auth: Optional[ModelAuth], device: Optional[str],
+ patch_method_exists: bool, marqo_index_type: Optional[IndexType], marqo_index_model: Optional[Model],
+ media_download_headers: Optional[Dict] = None,
+ audio_preprocessing: Optional[AudioPreProcessing] = None,
+ video_preprocessing: Optional[VideoPreProcessing] = None
+) -> dict:
docs_per_thread = math.ceil(len(docs) / thread_count)
copied = copy.deepcopy(docs)
@@ -373,7 +372,6 @@ def process_batch(docs: List[dict], thread_count: int, tensor_fields: List[str],
allocation,
media_repo,
tensor_fields,
- image_download_headers,
device,
media_field_types_mapping,
media_download_headers,
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index 0aaaf61ed..654f3a292 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -2372,7 +2372,7 @@ def vectorise_multimodal_combination_field_unstructured(field: str,
model_name=marqo_index.model.name,
model_properties=marqo_index.model.properties, content=prefixed_text_content_to_vectorise,
device=device, normalize_embeddings=normalize_embeddings,
- infer=False, model_auth=model_auth, modality=Modality.TEXT
+ infer=True, model_auth=model_auth, modality=Modality.TEXT
)
vectors_list.extend(text_vectors)
@@ -2596,7 +2596,7 @@ def vectorise_multimodal_combination_field_structured(
content=prefixed_text_content,
device=device,
normalize_embeddings=normalize_embeddings,
- infer=False,
+ infer=True,
model_auth=model_auth,
modality=Modality.TEXT
)
From aa2b1d64bdc34fc6e4a3f00401fcf112cf07f961 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 12:49:13 +1100
Subject: [PATCH 06/29] Revert changes in src/marqo/s2_inference/ and
 reconsider parameter passing
---
src/marqo/s2_inference/clip_utils.py | 9 +-
.../s2_inference/multimodal_model_load.py | 32 ++++---
src/marqo/s2_inference/onnx_clip_utils.py | 2 +-
src/marqo/s2_inference/s2_inference.py | 84 +++++--------------
4 files changed, 39 insertions(+), 88 deletions(-)
diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py
index 200b795ab..342e6d849 100644
--- a/src/marqo/s2_inference/clip_utils.py
+++ b/src/marqo/s2_inference/clip_utils.py
@@ -177,8 +177,6 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
c.setopt(pycurl.FOLLOWLOCATION, 1)
headers = DEFAULT_HEADERS.copy()
- if image_download_headers is None:
- image_download_headers = dict()
headers.update(image_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
@@ -469,8 +467,9 @@ def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType,
assert outputs.shape == _shape_before
return self._convert_output(outputs)
- def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], normalize=True, **kwargs) -> FloatTensor:
- default = "text"
+ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
+ default: str = 'text', normalize=True, **kwargs) -> FloatTensor:
+
infer = kwargs.pop('infer', True)
if infer and _is_image(inputs):
@@ -486,7 +485,7 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], nor
if is_image:
logger.debug('image')
- image_download_headers = kwargs.get("media_download_headers", dict())
+ image_download_headers = kwargs.get("image_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
else:
logger.debug('text')
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index 61e4992c6..173630c22 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -11,11 +11,10 @@
from pydantic import BaseModel
from enum import Enum
from abc import ABC, abstractmethod
-from typing import List, Dict, Any, Union, Optional
+from typing import List, Dict, Any, Union
from PIL.Image import Image
import torch
from urllib.parse import quote
-from marqo.core.inference.image_download import DEFAULT_HEADERS
from marqo.s2_inference.multimodal_model_load import *
@@ -110,15 +109,15 @@ def preprocessor(self, modality):
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
return self.encoder.preprocessor(modality)
- def encode(self, content, modality, media_download_headers: Optional[Dict]=None, **kwargs):
+ def encode(self, content, modality, **kwargs):
if self.encoder is None:
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
- return self.encoder.encode(content, modality, media_download_headers, **kwargs)
+ return self.encoder.encode(content, modality, **kwargs)
class ModelEncoder(ABC):
@abstractmethod
- def encode(self, content, modality, media_download_headers, **kwargs):
+ def encode(self, content, modality, **kwargs):
pass
@@ -126,14 +125,13 @@ class DefaultEncoder(ModelEncoder):
def __init__(self, model):
self.model = model
- def encode(self, content, modality, media_download_headers, **kwargs):
- return self.model.encode(content, modality=modality, media_download_headers=media_download_headers, **kwargs)
+ def encode(self, content, modality, **kwargs):
+ return self.model.encode(content, **kwargs)
@contextmanager
-def fetch_content_sample(url, media_download_headers: Optional[dict] = None, sample_size=10240): # 10 KB
- # It's ok to pass None to requests.get() for headers and it won't change the default headers
- response = requests.get(url, stream=True, headers=media_download_headers)
+def fetch_content_sample(url, sample_size=10240): # 10 KB
+ response = requests.get(url, stream=True)
buffer = io.BytesIO()
try:
for chunk in response.iter_content(chunk_size=min(sample_size, 8192)):
@@ -147,7 +145,7 @@ def fetch_content_sample(url, media_download_headers: Optional[dict] = None, sam
response.close()
-def infer_modality(content: Union[str, List[str], bytes], media_download_headers: Optional[dict] = None) -> Modality:
+def infer_modality(content: Union[str, List[str], bytes]) -> Modality:
"""
Infer the modality of the content. Video, audio, image or text.
"""
@@ -169,7 +167,7 @@ def infer_modality(content: Union[str, List[str], bytes], media_download_headers
if validate_url(encoded_url):
# Use context manager to handle content sample
try:
- with fetch_content_sample(encoded_url, media_download_headers) as sample:
+ with fetch_content_sample(encoded_url) as sample:
mime = magic.from_buffer(sample.read(), mime=True)
if mime.startswith('image/'):
return Modality.IMAGE
@@ -251,7 +249,7 @@ def preprocessor(self, modality):
return self._preprocessors.get(modality)
- def encode(self, content, modality, normalize=True, media_download_headers: Optional[Dict]=None, **kwargs):
+ def encode(self, content, modality, normalize=True, **kwargs):
inputs = {}
if modality == Modality.TEXT:
@@ -269,7 +267,7 @@ def encode(self, content, modality, normalize=True, media_download_headers: Opti
with open(temp_filename, 'wb') as f:
f.write(content)
elif isinstance(content, str) and "http" in content:
- self._download_content(content, temp_filename, media_download_headers)
+ self._download_content(content, temp_filename)
else:
return self.encode([content], modality=Modality.TEXT)
@@ -280,7 +278,7 @@ def encode(self, content, modality, normalize=True, media_download_headers: Opti
if isinstance(content, str) and "http" in content:
suffix = ".mp4" if modality == Modality.VIDEO else ".wav"
with self._temp_file(suffix) as temp_filename:
- self._download_content(content, temp_filename, media_download_headers)
+ self._download_content(content, temp_filename)
preprocessed_content = self.preprocessor(modality)([temp_filename], return_tensors='pt')
inputs[modality.value] = to_device(preprocessed_content, self.model.device)['pixel_values']
@@ -302,11 +300,11 @@ def encode(self, content, modality, normalize=True, media_download_headers: Opti
return embeddings.cpu().numpy()
- def _download_content(self, url, filename, media_download_headers: Optional[Dict]=None):
+ def _download_content(self, url, filename):
# 3 seconds for images, 20 seconds for audio and video
timeout_ms = 3000 if filename.endswith(('.png', '.jpg', '.jpeg')) else 20000
- buffer = download_image_from_url(url, media_download_headers, timeout_ms)
+ buffer = download_image_from_url(url, {}, timeout_ms)
with open(filename, 'wb') as f:
f.write(buffer.getvalue())
diff --git a/src/marqo/s2_inference/onnx_clip_utils.py b/src/marqo/s2_inference/onnx_clip_utils.py
index a9a6ee338..31da79185 100644
--- a/src/marqo/s2_inference/onnx_clip_utils.py
+++ b/src/marqo/s2_inference/onnx_clip_utils.py
@@ -167,7 +167,7 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
raise ValueError(f"expected default='image' or default='text' but received {default}")
if is_image:
- logger.debug('image'),
+ logger.debug('image')
return self.encode_image(inputs, normalize=True)
else:
logger.debug('text')
diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py
index a5d92fb34..fc97d5300 100644
--- a/src/marqo/s2_inference/s2_inference.py
+++ b/src/marqo/s2_inference/s2_inference.py
@@ -47,28 +47,8 @@
def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[bytes]],
model_properties: dict = None,
- device: str = None,
- normalize_embeddings: bool = get_default_normalization(),
- model_auth: ModelAuth = None,
- enable_cache: bool = False,
- modality: Modality = Modality.TEXT,
- media_download_headers: Optional[Dict] = None,
- infer: bool = True
- ) -> List[List[float]]:
- """Vectorise the given content using the given model.
-
- Args:
- model_name: The name of the model to use.
- content: The content to vectorise.
- model_properties: The properties of the model to use.
- device: The device to use.
- normalize_embeddings: Whether to normalize the embeddings.
- model_auth: The model authorisation details.
- enable_cache: Whether to enable the inference cache.
- modality: The modality of the content.
- media_download_headers: The media download headers.
- infer: Whether to infer the modality. Deprecated and should be replaced by modality.
- """
+ device: str = None, normalize_embeddings: bool = get_default_normalization(),
+ model_auth: ModelAuth = None, enable_cache: bool = False, modality: Modality = Modality.TEXT, **kwargs,) -> List[List[float]]:
if not device:
raise InternalError(message=f"vectorise (internal function) cannot be called without setting device!")
@@ -83,37 +63,25 @@ def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[
model = _available_models[model_cache_key][AvailableModelsKey.model]
if _marqo_inference_cache.is_enabled() and enable_cache:
- return _vectorise_with_cache(
- model, model_cache_key, content, normalize_embeddings, modality,
- media_download_headers, infer
- )
+ return _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs)
else:
- return _vectorise_without_cache(
- model_cache_key, content, normalize_embeddings, modality,
- media_download_headers, infer
- )
+ return _vectorise_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
-def _vectorise_with_cache(model, model_cache_key: str, content, normalize_embeddings: bool, modality: Modality,
- media_download_headers: Optional[Dict], infer: bool):
+def _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs):
if isinstance(content, str):
vectorised = _marqo_inference_cache.get(model_cache_key, content)
if vectorised is None:
- vectorised = _encode_without_cache(
- model_cache_key, content, normalize_embeddings, modality, media_download_headers, infer
- )
+ vectorised = _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
_marqo_inference_cache.set(model_cache_key, content, vectorised[0])
else:
vectorised = _convert_cached_embeddings_to_output(vectorised)
return vectorised
elif isinstance(content, list):
- return _vectorise_list_with_cache(
- model, model_cache_key, content, normalize_embeddings, modality, media_download_headers, infer
- )
+ return _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs)
else:
raise TypeError(f"Unsupported content type: {type(content).__name__}")
-def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality,
- media_download_headers, infer):
+def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs):
contents_to_vectorise = []
cached_output = []
@@ -129,8 +97,7 @@ def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embedd
contents_to_vectorise.append(content_item)
if contents_to_vectorise:
- vectorised_outputs = _encode_without_cache(
- model_cache_key, contents_to_vectorise, normalize_embeddings, modality, media_download_headers, infer)
+ vectorised_outputs = _encode_without_cache(model_cache_key, contents_to_vectorise, normalize_embeddings, modality, **kwargs)
# Cache the vectorised outputs
for content_item, vectorised_output in zip(contents_to_vectorise, vectorised_outputs):
if isinstance(content_item, str):
@@ -143,32 +110,20 @@ def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embedd
return vectorised_outputs
+def _vectorise_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
+ normalize_embeddings: bool, modality: Modality, **kwargs) -> List[List[float]]:
+ return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
-def _vectorise_without_cache(
- model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
- normalize_embeddings: bool, modality: Modality,
- media_download_headers: Optional[Dict], infer: bool
-) -> List[List[float]]:
- return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, media_download_headers, infer)
-
-def _encode_without_cache(
- model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
- normalize_embeddings: bool, modality: Modality, media_download_headers: Optional[Dict], infer: bool) \
- -> List[List[float]]:
+def _encode_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
+ normalize_embeddings: bool, modality: Modality, **kwargs) -> List[List[float]]:
try:
model = _available_models[model_cache_key][AvailableModelsKey.model]
encoder = get_encoder(model)
if isinstance(content, str):
- vectorised = model.encode(
- content, normalize=normalize_embeddings, modality=modality,
- media_download_headers=media_download_headers, infer=infer
- )
+ vectorised = model.encode(content, normalize=normalize_embeddings, modality=modality, **kwargs)
elif isinstance(content, (torch.Tensor, torch.FloatTensor)):
- vectorised = model.encode(
- content, normalize=normalize_embeddings, modality=modality,
- media_download_headers=media_download_headers, infer=infer
- )
+ vectorised = model.encode(content, normalize=normalize_embeddings, modality=modality, **kwargs)
else:
vector_batches = []
batch_size = _get_max_vectorise_batch_size()
@@ -178,10 +133,9 @@ def _encode_without_cache(
modality = infer_modality(batch[0] if isinstance(batch[0], (str, bytes)) else batch)
# TODO maybe the infer parameter can be replaced by modality
- encoded_batch = encoder.encode(
- batch, modality=modality, normalize=normalize_embeddings,
- infer=infer, media_download_headers=media_download_headers
- )
+ infer = kwargs.pop('infer', False if modality == Modality.TEXT else True)
+ encoded_batch = encoder.encode(batch, modality=modality, normalize=normalize_embeddings,
+ infer=infer, **kwargs)
vector_batches.append(_convert_tensor_to_numpy(encoded_batch))
From 2dbdb009d895a6b759648a75e86fc1e83c5095af Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 14:13:29 +1100
Subject: [PATCH 07/29] Fix tests
---
src/marqo/s2_inference/clip_utils.py | 4 +-
.../s2_inference/multimodal_model_load.py | 32 +++++-----
src/marqo/s2_inference/s2_inference.py | 58 +++++++++++++------
3 files changed, 61 insertions(+), 33 deletions(-)
diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py
index 342e6d849..ff787dd08 100644
--- a/src/marqo/s2_inference/clip_utils.py
+++ b/src/marqo/s2_inference/clip_utils.py
@@ -177,6 +177,8 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
c.setopt(pycurl.FOLLOWLOCATION, 1)
headers = DEFAULT_HEADERS.copy()
+ if image_download_headers is None:
+ image_download_headers = dict()
headers.update(image_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
@@ -485,7 +487,7 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
if is_image:
logger.debug('image')
- image_download_headers = kwargs.get("image_download_headers", dict())
+ image_download_headers = kwargs.get("media_download_headers", dict())
return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
else:
logger.debug('text')
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index 173630c22..61e4992c6 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -11,10 +11,11 @@
from pydantic import BaseModel
from enum import Enum
from abc import ABC, abstractmethod
-from typing import List, Dict, Any, Union
+from typing import List, Dict, Any, Union, Optional
from PIL.Image import Image
import torch
from urllib.parse import quote
+from marqo.core.inference.image_download import DEFAULT_HEADERS
from marqo.s2_inference.multimodal_model_load import *
@@ -109,15 +110,15 @@ def preprocessor(self, modality):
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
return self.encoder.preprocessor(modality)
- def encode(self, content, modality, **kwargs):
+ def encode(self, content, modality, media_download_headers: Optional[Dict]=None, **kwargs):
if self.encoder is None:
raise ValueError("Model has not been loaded yet. Call _load_model() first.")
- return self.encoder.encode(content, modality, **kwargs)
+ return self.encoder.encode(content, modality, media_download_headers, **kwargs)
class ModelEncoder(ABC):
@abstractmethod
- def encode(self, content, modality, **kwargs):
+ def encode(self, content, modality, media_download_headers, **kwargs):
pass
@@ -125,13 +126,14 @@ class DefaultEncoder(ModelEncoder):
def __init__(self, model):
self.model = model
- def encode(self, content, modality, **kwargs):
- return self.model.encode(content, **kwargs)
+ def encode(self, content, modality, media_download_headers, **kwargs):
+ return self.model.encode(content, modality=modality, media_download_headers=media_download_headers, **kwargs)
@contextmanager
-def fetch_content_sample(url, sample_size=10240): # 10 KB
- response = requests.get(url, stream=True)
+def fetch_content_sample(url, media_download_headers: Optional[dict] = None, sample_size=10240): # 10 KB
+ # It's ok to pass None to requests.get() for headers and it won't change the default headers
+ response = requests.get(url, stream=True, headers=media_download_headers)
buffer = io.BytesIO()
try:
for chunk in response.iter_content(chunk_size=min(sample_size, 8192)):
@@ -145,7 +147,7 @@ def fetch_content_sample(url, sample_size=10240): # 10 KB
response.close()
-def infer_modality(content: Union[str, List[str], bytes]) -> Modality:
+def infer_modality(content: Union[str, List[str], bytes], media_download_headers: Optional[dict] = None) -> Modality:
"""
Infer the modality of the content. Video, audio, image or text.
"""
@@ -167,7 +169,7 @@ def infer_modality(content: Union[str, List[str], bytes]) -> Modality:
if validate_url(encoded_url):
# Use context manager to handle content sample
try:
- with fetch_content_sample(encoded_url) as sample:
+ with fetch_content_sample(encoded_url, media_download_headers) as sample:
mime = magic.from_buffer(sample.read(), mime=True)
if mime.startswith('image/'):
return Modality.IMAGE
@@ -249,7 +251,7 @@ def preprocessor(self, modality):
return self._preprocessors.get(modality)
- def encode(self, content, modality, normalize=True, **kwargs):
+ def encode(self, content, modality, normalize=True, media_download_headers: Optional[Dict]=None, **kwargs):
inputs = {}
if modality == Modality.TEXT:
@@ -267,7 +269,7 @@ def encode(self, content, modality, normalize=True, **kwargs):
with open(temp_filename, 'wb') as f:
f.write(content)
elif isinstance(content, str) and "http" in content:
- self._download_content(content, temp_filename)
+ self._download_content(content, temp_filename, media_download_headers)
else:
return self.encode([content], modality=Modality.TEXT)
@@ -278,7 +280,7 @@ def encode(self, content, modality, normalize=True, **kwargs):
if isinstance(content, str) and "http" in content:
suffix = ".mp4" if modality == Modality.VIDEO else ".wav"
with self._temp_file(suffix) as temp_filename:
- self._download_content(content, temp_filename)
+ self._download_content(content, temp_filename, media_download_headers)
preprocessed_content = self.preprocessor(modality)([temp_filename], return_tensors='pt')
inputs[modality.value] = to_device(preprocessed_content, self.model.device)['pixel_values']
@@ -300,11 +302,11 @@ def encode(self, content, modality, normalize=True, **kwargs):
return embeddings.cpu().numpy()
- def _download_content(self, url, filename):
+ def _download_content(self, url, filename, media_download_headers: Optional[Dict]=None):
# 3 seconds for images, 20 seconds for audio and video
timeout_ms = 3000 if filename.endswith(('.png', '.jpg', '.jpeg')) else 20000
- buffer = download_image_from_url(url, {}, timeout_ms)
+ buffer = download_image_from_url(url, media_download_headers, timeout_ms)
with open(filename, 'wb') as f:
f.write(buffer.getvalue())
diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py
index fc97d5300..d1b60a606 100644
--- a/src/marqo/s2_inference/s2_inference.py
+++ b/src/marqo/s2_inference/s2_inference.py
@@ -45,10 +45,12 @@
-def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[bytes]],
- model_properties: dict = None,
- device: str = None, normalize_embeddings: bool = get_default_normalization(),
- model_auth: ModelAuth = None, enable_cache: bool = False, modality: Modality = Modality.TEXT, **kwargs,) -> List[List[float]]:
+def vectorise(
+ model_name: str, content: Union[str, List[str], List[Image], List[bytes]],
+ model_properties: dict = None,
+ device: str = None, normalize_embeddings: bool = get_default_normalization(),
+ model_auth: ModelAuth = None, enable_cache: bool = False, modality: Modality = Modality.TEXT,
+ media_download_headers: Optional[Dict] = None, **kwargs) -> List[List[float]]:
if not device:
raise InternalError(message=f"vectorise (internal function) cannot be called without setting device!")
@@ -63,25 +65,36 @@ def vectorise(model_name: str, content: Union[str, List[str], List[Image], List[
model = _available_models[model_cache_key][AvailableModelsKey.model]
if _marqo_inference_cache.is_enabled() and enable_cache:
- return _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality,
+ media_download_headers, **kwargs)
else:
- return _vectorise_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _vectorise_without_cache(model_cache_key, content, normalize_embeddings, modality, media_download_headers,
+ **kwargs)
-def _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs):
+def _vectorise_with_cache(model, model_cache_key, content, normalize_embeddings, modality, media_download_headers,
+ **kwargs):
if isinstance(content, str):
vectorised = _marqo_inference_cache.get(model_cache_key, content)
if vectorised is None:
- vectorised = _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ vectorised = _encode_without_cache(
+ model_cache_key, content, normalize_embeddings, modality, media_download_headers,
+ **kwargs
+ )
_marqo_inference_cache.set(model_cache_key, content, vectorised[0])
else:
vectorised = _convert_cached_embeddings_to_output(vectorised)
return vectorised
elif isinstance(content, list):
- return _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _vectorise_list_with_cache(
+ model, model_cache_key, content, normalize_embeddings, modality,
+ media_download_headers,
+ **kwargs
+ )
else:
raise TypeError(f"Unsupported content type: {type(content).__name__}")
-def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, **kwargs):
+def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embeddings, modality, media_download_headers,
+ **kwargs):
contents_to_vectorise = []
cached_output = []
@@ -97,7 +110,10 @@ def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embedd
contents_to_vectorise.append(content_item)
if contents_to_vectorise:
- vectorised_outputs = _encode_without_cache(model_cache_key, contents_to_vectorise, normalize_embeddings, modality, **kwargs)
+ vectorised_outputs = _encode_without_cache(
+ model_cache_key, contents_to_vectorise, normalize_embeddings, modality,
+ media_download_headers, **kwargs
+ )
# Cache the vectorised outputs
for content_item, vectorised_output in zip(contents_to_vectorise, vectorised_outputs):
if isinstance(content_item, str):
@@ -110,18 +126,25 @@ def _vectorise_list_with_cache(model, model_cache_key, content, normalize_embedd
return vectorised_outputs
-def _vectorise_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
- normalize_embeddings: bool, modality: Modality, **kwargs) -> List[List[float]]:
+
+def _vectorise_without_cache(
+ model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
+ normalize_embeddings: bool, modality: Modality, media_download_headers,
+ **kwargs) -> List[List[float]]:
return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
def _encode_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
- normalize_embeddings: bool, modality: Modality, **kwargs) -> List[List[float]]:
+ normalize_embeddings: bool, modality: Modality, media_download_headers: Optional[Dict]=None,
+ **kwargs) -> List[List[float]]:
try:
model = _available_models[model_cache_key][AvailableModelsKey.model]
encoder = get_encoder(model)
if isinstance(content, str):
- vectorised = model.encode(content, normalize=normalize_embeddings, modality=modality, **kwargs)
+ vectorised = model.encode(
+ content, normalize=normalize_embeddings, modality=modality,
+ media_download_headers=media_download_headers, **kwargs
+ )
elif isinstance(content, (torch.Tensor, torch.FloatTensor)):
vectorised = model.encode(content, normalize=normalize_embeddings, modality=modality, **kwargs)
else:
@@ -134,8 +157,9 @@ def _encode_without_cache(model_cache_key: str, content: Union[str, List[str], L
# TODO maybe the infer parameter can be replaced by modality
infer = kwargs.pop('infer', False if modality == Modality.TEXT else True)
- encoded_batch = encoder.encode(batch, modality=modality, normalize=normalize_embeddings,
- infer=infer, **kwargs)
+ encoded_batch = encoder.encode(
+ batch, modality=modality, normalize=normalize_embeddings,
+ media_download_headers=media_download_headers, infer = infer, **kwargs)
vector_batches.append(_convert_tensor_to_numpy(encoded_batch))
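
After this patch, media_download_headers is an explicit keyword on vectorise and is threaded through both the cached and uncached paths down to the encoder. An illustrative call, assuming a working Marqo inference environment; the model name, URL, header value and import paths are assumptions for illustration rather than anything mandated by the patch:

from marqo.s2_inference.multimodal_model_load import Modality
from marqo.s2_inference.s2_inference import vectorise

# Placeholder values -- illustration only.
embeddings = vectorise(
    model_name="open_clip/ViT-B-32/laion2b_s34b_b79k",
    content=["https://example.com/private/image.png"],
    device="cpu",  # vectorise raises InternalError if device is not set
    modality=Modality.IMAGE,
    media_download_headers={"Authorization": "Bearer <token>"},
)
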
From 6e4b924428bb2f8eccb8194744f42ab6e2155f10 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 14:26:43 +1100
Subject: [PATCH 08/29] Fix hybrid
---
src/marqo/core/search/hybrid_search.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/src/marqo/core/search/hybrid_search.py b/src/marqo/core/search/hybrid_search.py
index 3bc2e4ead..2dfd818c3 100644
--- a/src/marqo/core/search/hybrid_search.py
+++ b/src/marqo/core/search/hybrid_search.py
@@ -33,7 +33,7 @@ def search(
offset: int = 0, ef_search: Optional[int] = None, approximate: bool = True,
searchable_attributes: Iterable[str] = None, filter_string: str = None, device: str = None,
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None,
- image_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None,
+ media_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None,
score_modifiers: Optional[ScoreModifierLists] = None, model_auth: Optional[ModelAuth] = None,
highlights: bool = False, text_query_prefix: Optional[str] = None,
hybrid_parameters: HybridParameters = None) -> Dict:
@@ -51,7 +51,8 @@ def search(
verbose: if 0 - nothing is printed. if 1 - data is printed without vectors, if 2 - full
objects are printed out
attributes_to_retrieve: if set, only returns these fields
- image_download_headers: headers for downloading images
+ media_download_headers: headers for downloading media
+
context: a dictionary to allow custom vectors in search
score_modifiers: a dictionary to modify the score based on field values, should be None for hybrid search
model_auth: Authorisation details for downloading a model (if required)
@@ -151,7 +152,7 @@ def search(
q=query_text_vectorise, searchableAttributes=searchable_attributes, searchMethod=SearchMethod.HYBRID,
limit=result_count,
offset=offset, showHighlights=False, filter=filter_string, attributesToRetrieve=attributes_to_retrieve,
- boost=boost, image_download_headers=image_download_headers, context=context, scoreModifiers=score_modifiers,
+ boost=boost, media_download_headers=media_download_headers, context=context, scoreModifiers=score_modifiers,
index=marqo_index, modelAuth=model_auth, text_query_prefix=text_query_prefix,
hybridParameters=hybrid_parameters
)]
From 0f84ba6b8cbb5700d8763df6db864b9f9d47140b Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 14:43:30 +1100
Subject: [PATCH 09/29] Fix hybrid tests
---
src/marqo/core/search/hybrid_search.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/marqo/core/search/hybrid_search.py b/src/marqo/core/search/hybrid_search.py
index 2dfd818c3..9ee8a5264 100644
--- a/src/marqo/core/search/hybrid_search.py
+++ b/src/marqo/core/search/hybrid_search.py
@@ -152,7 +152,7 @@ def search(
q=query_text_vectorise, searchableAttributes=searchable_attributes, searchMethod=SearchMethod.HYBRID,
limit=result_count,
offset=offset, showHighlights=False, filter=filter_string, attributesToRetrieve=attributes_to_retrieve,
- boost=boost, media_download_headers=media_download_headers, context=context, scoreModifiers=score_modifiers,
+ boost=boost, mediaDownloadHeaders=media_download_headers, context=context, scoreModifiers=score_modifiers,
index=marqo_index, modelAuth=model_auth, text_query_prefix=text_query_prefix,
hybridParameters=hybrid_parameters
)]
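
The one-word change above is needed because the query model declares the camelCase field directly, so constructing it with the snake_case keyword would not populate the headers. A minimal, hypothetical Pydantic sketch (not the real Marqo class) of why the camelCase keyword matters:

from typing import Dict, Optional
from pydantic import BaseModel

class QuerySketch(BaseModel):
    """Hypothetical stand-in for the search query model."""
    q: str
    mediaDownloadHeaders: Optional[Dict] = None

ok = QuerySketch(q="hello", mediaDownloadHeaders={"Authorization": "Bearer <token>"})
assert ok.mediaDownloadHeaders is not None

# With default pydantic settings an unknown snake_case keyword is silently dropped,
# so the headers would never reach the search pipeline.
dropped = QuerySketch(q="hello", media_download_headers={"Authorization": "Bearer <token>"})
assert dropped.mediaDownloadHeaders is None
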
From 8afca5f24f05611744d72ee77cb0d55b9079bfc9 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 15:16:57 +1100
Subject: [PATCH 10/29] Fix embed
---
src/marqo/api/models/embed_request.py | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/src/marqo/api/models/embed_request.py b/src/marqo/api/models/embed_request.py
index ff16f6a3a..27bee6d8f 100644
--- a/src/marqo/api/models/embed_request.py
+++ b/src/marqo/api/models/embed_request.py
@@ -9,18 +9,23 @@
from pydantic import Field, root_validator
from marqo.tensor_search.models.private_models import ModelAuth
-from marqo.tensor_search.models.api_models import BaseMarqoModel
+from marqo.base_model import MarqoBaseModel
from marqo.core.embed.embed import EmbedContentType
-class EmbedRequest(BaseMarqoModel):
+class EmbedRequest(MarqoBaseModel):
# content can be a single query or list of queries. Queries can be a string or a dictionary.
content: Union[str, Dict[str, float], List[Union[str, Dict[str, float]]]]
image_download_headers: Optional[Dict] = Field(default=None, alias="imageDownloadHeaders")
- mediaDownloadHeaders: Optional[Dict] = Field(default=None, alias="mediaDownloadHeaders")
+ mediaDownloadHeaders: Optional[Dict] = None
modelAuth: Optional[ModelAuth] = None
- content_type: Optional[EmbedContentType] = Field(EmbedContentType.Query, alias=("contentType"))
+ content_type: Optional[EmbedContentType] = Field(default=EmbedContentType.Query, alias="contentType")
+
+ @root_validator(pre=True)
+ def _test(cls, values):
+ print(values)
+ return values
@pydantic.validator('content')
def validate_content(cls, value):
From 414df743d7c992a3b2240157306401ceab82631b Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 16:14:37 +1100
Subject: [PATCH 11/29] Fix embed
---
.../semi_structured_add_document_handler.py | 5 +-
.../unstructured_add_document_handler.py | 5 +-
.../s2_inference/multimodal_model_load.py | 8 ++-
src/marqo/tensor_search/tensor_search.py | 2 +-
.../test_add_documents_combined.py | 61 ++++++++++++++++++-
5 files changed, 75 insertions(+), 6 deletions(-)
diff --git a/src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py b/src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py
index c82ae3fc1..74ffb4073 100644
--- a/src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py
+++ b/src/marqo/core/semi_structured_vespa_index/semi_structured_add_document_handler.py
@@ -41,7 +41,10 @@ def __init__(self, marqo_index: SemiStructuredMarqoIndex, add_docs_params: AddDo
def _handle_field(self, marqo_doc, field_name, field_content):
self._validate_field(field_name, field_content)
- text_field_type = self._infer_field_type(field_content)
+ text_field_type = self._infer_field_type(
+ field_content,
+ media_download_headers=self.add_docs_params.media_download_headers
+ )
content = self.tensor_fields_container.collect(marqo_doc[MARQO_DOC_ID], field_name,
field_content, text_field_type)
marqo_doc[field_name] = content
diff --git a/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py b/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
index 31c3b300c..c9f89ccd9 100644
--- a/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
+++ b/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
@@ -70,12 +70,13 @@ def _handle_field(self, marqo_doc, field_name, field_content):
field_content, text_field_type)
marqo_doc[field_name] = content
- def _infer_field_type(self, field_content: Any) -> Optional[FieldType]:
+ def _infer_field_type(self, field_content: Any, media_download_headers: Optional[Dict] = None) \
+ -> Optional[FieldType]:
if not isinstance(field_content, str):
return None
try:
- modality = infer_modality(field_content)
+ modality = infer_modality(field_content, media_download_headers)
if not self.marqo_index.treat_urls_and_pointers_as_media and modality in [Modality.AUDIO, Modality.VIDEO]:
modality = Modality.TEXT
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index 61e4992c6..3a73326b3 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -133,7 +133,14 @@ def encode(self, content, modality, media_download_headers, **kwargs):
@contextmanager
def fetch_content_sample(url, media_download_headers: Optional[dict] = None, sample_size=10240): # 10 KB
# It's ok to pass None to requests.get() for headers and it won't change the default headers
+ """Fetch a sample of the content from the URL.
+
+ Raises:
+ HTTPError: If the response status code is not 200
+ """
response = requests.get(url, stream=True, headers=media_download_headers)
+ if response.status_code != 200:
+ response.raise_for_status()
buffer = io.BytesIO()
try:
for chunk in response.iter_content(chunk_size=min(sample_size, 8192)):
@@ -157,7 +164,6 @@ def infer_modality(content: Union[str, List[str], bytes], media_download_headers
# Encode the URL
encoded_url = encode_url(content)
-
extension = encoded_url.split('.')[-1].lower()
if extension in ['jpg', 'jpeg', 'png', 'gif', 'webp']:
return Modality.IMAGE
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index 654f3a292..e84201d03 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -2036,7 +2036,7 @@ def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSear
# Apply prefix if key is not an image or if index does not treat URLs and pointers as images
modality = infer_modality(key, q.mediaDownloadHeaders)
if modality == Modality.TEXT:
- prefixed_q[key] = f"{text_query_prefix}{value}"
+ prefixed_q[f"{text_query_prefix}{key}"] = value
else:
prefixed_q[key] = value
new_query_object = BulkSearchQueryEntity(
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index 21c3b90b4..417dc3028 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -15,6 +15,7 @@
import requests
import torch
from more_itertools import flatten
+from numpy.ma.core import subtract
from torch import Tensor
import unittest.mock
@@ -1092,4 +1093,62 @@ def test_textIndexEmbeddingsUnnormalized(self):
embeddings = get_res['results'][0]['_tensor_facets'][0]['_embedding']
norm = np.linalg.norm(np.array(embeddings))
- self.assertTrue(norm - 1.0 > 1e-5, f"Embedding norm is {norm}")
\ No newline at end of file
+ self.assertTrue(norm - 1.0 > 1e-5, f"Embedding norm is {norm}")
+
+ def test_add_private_images_proper_error_returned(self):
+ """Test to ensure that private images can not be downloaded and an appropriate error is returned"""
+ test_indexes = [self.structured_marqo_index_name, self.unstructured_marqo_index_name]
+ documents = [
+ {
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
+ "_id": "1"
+ },
+ {
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
+ "_id": "2"
+ }
+ ]
+ for index_name in test_indexes:
+ tensor_fields = ["image_field_1"] if index_name == self.unstructured_marqo_index_name else None
+ with self.subTest(index_name):
+ res = tensor_search.add_documents(
+ self.config,
+ add_docs_params=AddDocsParams(
+ docs=documents,
+ index_name=index_name,
+ tensor_fields=tensor_fields
+ )
+ )
+ self.assertTrue(res.errors)
+ items = res.items
+ self.assertEqual(2, len(items))
+ for item in items:
+ self.assertEqual(400, item.status)
+ self.assertIn("403", item.message)
+
+ def test_add_private_images_success(self):
+ """Test to ensure that private images can be downloaded with proper headers"""
+ test_indexes = [self.structured_marqo_index_name, self.unstructured_marqo_index_name]
+ documents = [
+ {
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
+ "_id": "1"
+ },
+ {
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
+ "_id": "2"
+ }
+ ]
+ for index_name in test_indexes:
+ tensor_fields = ["image_field_1"] if index_name == self.unstructured_marqo_index_name else None
+ with self.subTest(index_name):
+ res = tensor_search.add_documents(
+ self.config,
+ add_docs_params=AddDocsParams(
+ docs=documents,
+ index_name=index_name,
+ tensor_fields=tensor_fields,
+ media_download_headers={"marqo_media_header": "media_header_test_key"}
+ )
+ )
+ self.assertFalse(res.errors)
\ No newline at end of file
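
With the headers now reaching infer_modality, modality detection for extensionless private URLs (like the second test document above) can authenticate its sample request instead of sniffing a 403 error page. A rough, self-contained consolidation of the sampling logic as it stands after this patch (function renamed with a _sketch suffix; URL and header values are placeholders):

import io
from contextlib import contextmanager

import magic      # python-magic, as used by the patched module
import requests

@contextmanager
def fetch_content_sample_sketch(url, media_download_headers=None, sample_size=10240):
    # Passing headers=None leaves the requests defaults untouched.
    response = requests.get(url, stream=True, headers=media_download_headers)
    response.raise_for_status()   # surface 403/404 instead of guessing from an error page
    buffer = io.BytesIO()
    try:
        for chunk in response.iter_content(chunk_size=min(sample_size, 8192)):
            buffer.write(chunk)
            if buffer.tell() >= sample_size:
                break
        buffer.seek(0)
        yield buffer
    finally:
        buffer.close()
        response.close()

# Placeholder URL and header value.
with fetch_content_sample_sketch("https://example.com/private/asset",
                                 {"Authorization": "Bearer <token>"}) as sample:
    mime = magic.from_buffer(sample.read(), mime=True)   # e.g. "image/png"
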
From b5e2195d1e333b8b07fead14141e7bb12aa7724c Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 16:39:42 +1100
Subject: [PATCH 12/29] Add add_documents tests and search tests
---
src/marqo/s2_inference/s2_inference.py | 2 +-
.../integ_tests/test_search_combined.py | 62 ++++++++++++++++---
2 files changed, 55 insertions(+), 9 deletions(-)
diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py
index d1b60a606..ce01848f5 100644
--- a/src/marqo/s2_inference/s2_inference.py
+++ b/src/marqo/s2_inference/s2_inference.py
@@ -131,7 +131,7 @@ def _vectorise_without_cache(
model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
normalize_embeddings: bool, modality: Modality, media_download_headers,
**kwargs) -> List[List[float]]:
- return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, **kwargs)
+ return _encode_without_cache(model_cache_key, content, normalize_embeddings, modality, media_download_headers, **kwargs)
def _encode_without_cache(model_cache_key: str, content: Union[str, List[str], List[Image], List[bytes]],
normalize_embeddings: bool, modality: Modality, media_download_headers: Optional[Dict]=None,
diff --git a/tests/tensor_search/integ_tests/test_search_combined.py b/tests/tensor_search/integ_tests/test_search_combined.py
index e5e26674f..a87e7cabf 100644
--- a/tests/tensor_search/integ_tests/test_search_combined.py
+++ b/tests/tensor_search/integ_tests/test_search_combined.py
@@ -1,23 +1,25 @@
import os
import uuid
from unittest import mock
-import torch
+
import pytest
+import torch
+from pydantic import ValidationError
import marqo.core.exceptions as core_exceptions
+from marqo import exceptions as base_exceptions
+from marqo.core.models.add_docs_params import AddDocsParams
from marqo.core.models.marqo_index import *
from marqo.core.models.marqo_index_request import FieldRequest
-from marqo.tensor_search import tensor_search
-from marqo.tensor_search.enums import SearchMethod
-from marqo.core.models.add_docs_params import AddDocsParams
-from tests.marqo_test import MarqoTestCase, TestImageUrls
-from marqo import exceptions as base_exceptions
from marqo.core.models.marqo_query import MarqoLexicalQuery
from marqo.core.models.score_modifier import ScoreModifierType, ScoreModifier
from marqo.core.structured_vespa_index.structured_vespa_index import StructuredVespaIndex
from marqo.core.unstructured_vespa_index.unstructured_vespa_index import UnstructuredVespaIndex
+from marqo.s2_inference.errors import MediaDownloadError
+from marqo.tensor_search import tensor_search
+from marqo.tensor_search.enums import SearchMethod
from marqo.tensor_search.models.api_models import SearchQuery
-from pydantic import ValidationError
+from tests.marqo_test import MarqoTestCase, TestImageUrls
class TestSearch(MarqoTestCase):
@@ -965,4 +967,48 @@ def test_search_query_CanAcceptDifferentSearchMethods(self):
# A special case for no search method provided
search_query = SearchQuery(q="test")
- self.assertEqual(SearchMethod.TENSOR, search_query.searchMethod)
\ No newline at end of file
+ self.assertEqual(SearchMethod.TENSOR, search_query.searchMethod)
+
+ def test_search_private_images_proper_error_raised(self):
+ """Test that search raises a MediaDownloadError when trying to access private images"""
+ test_indexes = [
+ self.unstructured_default_image_index,
+ self.structured_default_image_index
+ ]
+
+ test_queries = [({
+ "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png": 1,
+ "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small": 1 }, "dictionary queries"),
+ ("https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "str queries")]
+ for index_name in test_indexes:
+ for query, msg in test_queries:
+ with self.subTest(msg=f"index: {index_name}, query: {msg}"):
+ with self.assertRaises(MediaDownloadError):
+ _ = tensor_search.search(
+ config=self.config,
+ index_name=index_name.name,
+ text=query,
+ search_method=SearchMethod.TENSOR,
+ )
+
+ def test_search_over_private_images_with_media_download_headers(self):
+ """Test that search can use private images with media download headers"""
+ test_indexes = [
+ self.unstructured_default_image_index,
+ self.structured_default_image_index
+ ]
+
+ test_queries = [({
+ "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png": 1,
+ "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small": 1 }, "dictionary queries"),
+ ("https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "str queries")]
+ for index_name in test_indexes:
+ for query, msg in test_queries:
+ with self.subTest(msg=f"index: {index_name}, query: {msg}"):
+ _ = tensor_search.search(
+ config=self.config,
+ index_name=index_name.name,
+ text=query,
+ search_method=SearchMethod.TENSOR,
+ media_download_headers={"marqo_media_header": "media_header_test_key"}
+ )
\ No newline at end of file
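
At the HTTP layer the same headers travel as mediaDownloadHeaders in the search request body. An illustrative request against a local instance; the port, index name, URL and header value are placeholders and not taken from these patches:

import requests

resp = requests.post(
    "http://localhost:8882/indexes/my-image-index/search",
    json={
        "q": "https://example.com/private/image.png",
        "searchMethod": "TENSOR",
        "mediaDownloadHeaders": {"Authorization": "Bearer <token>"},
    },
)
print(resp.json())
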
From 9b2a08a0adbc4328937a69c84096619ba872f652 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 17:12:20 +1100
Subject: [PATCH 13/29] Respond to Farshid's comments
---
src/marqo/api/models/add_docs_objects.py | 2 +-
src/marqo/api/models/embed_request.py | 7 +---
.../s2_inference/multimodal_model_load.py | 3 +-
src/marqo/tensor_search/models/api_models.py | 2 +-
src/marqo/tensor_search/tensor_search.py | 2 +-
.../tensor_search/test_modalities_download.py | 32 +++++++++----------
6 files changed, 21 insertions(+), 27 deletions(-)
diff --git a/src/marqo/api/models/add_docs_objects.py b/src/marqo/api/models/add_docs_objects.py
index 2174753e6..ad5c2b81d 100644
--- a/src/marqo/api/models/add_docs_objects.py
+++ b/src/marqo/api/models/add_docs_objects.py
@@ -53,7 +53,7 @@ def _validate_image_download_headers_and_media_download_headers(cls, values):
media_download_headers = values.get('mediaDownloadHeaders')
if image_download_headers and media_download_headers:
raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
- "The imageDownloadHeaders is deprecated and will be removed in the future. "
+ "'imageDownloadHeaders' is deprecated and will be removed in the future. "
"Use mediaDownloadHeaders instead.")
if image_download_headers:
values['mediaDownloadHeaders'] = image_download_headers
diff --git a/src/marqo/api/models/embed_request.py b/src/marqo/api/models/embed_request.py
index 27bee6d8f..c1373da6d 100644
--- a/src/marqo/api/models/embed_request.py
+++ b/src/marqo/api/models/embed_request.py
@@ -22,11 +22,6 @@ class EmbedRequest(MarqoBaseModel):
modelAuth: Optional[ModelAuth] = None
content_type: Optional[EmbedContentType] = Field(default=EmbedContentType.Query, alias="contentType")
- @root_validator(pre=True)
- def _test(cls, values):
- print(values)
- return values
-
@pydantic.validator('content')
def validate_content(cls, value):
# Iterate through content list items
@@ -70,7 +65,7 @@ def _validate_image_download_headers_and_media_download_headers(cls, values):
media_download_headers = values.get('mediaDownloadHeaders')
if image_download_headers and media_download_headers:
raise ValueError("Cannot set both imageDownloadHeaders and mediaDownloadHeaders. "
- "The imageDownloadHeaders is deprecated and will be removed in the future. "
+ "'imageDownloadHeaders' is deprecated and will be removed in the future. "
"Use mediaDownloadHeaders instead.")
if image_download_headers:
values['mediaDownloadHeaders'] = image_download_headers
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index 3a73326b3..2dc6da1bd 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -139,8 +139,7 @@ def fetch_content_sample(url, media_download_headers: Optional[dict] = None, sam
HTTPError: If the response status code is not 200
"""
response = requests.get(url, stream=True, headers=media_download_headers)
- if response.status_code != 200:
- response.raise_for_status()
+ response.raise_for_status()
buffer = io.BytesIO()
try:
for chunk in response.iter_content(chunk_size=min(sample_size, 8192)):
diff --git a/src/marqo/tensor_search/models/api_models.py b/src/marqo/tensor_search/models/api_models.py
index d688e55f0..3f4bccd97 100644
--- a/src/marqo/tensor_search/models/api_models.py
+++ b/src/marqo/tensor_search/models/api_models.py
@@ -82,7 +82,7 @@ def _validate_image_download_headers_and_media_download_headers(cls, values):
media_download_headers = values.get('mediaDownloadHeaders')
if image_download_headers and media_download_headers:
raise ValueError("Cannot set both imageDownloadHeaders(image_download_headers) and mediaDownloadHeaders. "
- "The imageDownloadHeaders(image_download_headers) is deprecated and will be removed in the future. "
+ "'imageDownloadHeaders'(image_download_headers) is deprecated and will be removed in the future. "
"Use mediaDownloadHeaders instead.")
if image_download_headers:
values['mediaDownloadHeaders'] = image_download_headers
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index e84201d03..9e7381379 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -2596,7 +2596,7 @@ def vectorise_multimodal_combination_field_structured(
content=prefixed_text_content,
device=device,
normalize_embeddings=normalize_embeddings,
- infer=True,
+ infer=False,
model_auth=model_auth,
modality=Modality.TEXT
)
diff --git a/tests/tensor_search/test_modalities_download.py b/tests/tensor_search/test_modalities_download.py
index 55142a5cf..b7158b2be 100644
--- a/tests/tensor_search/test_modalities_download.py
+++ b/tests/tensor_search/test_modalities_download.py
@@ -62,7 +62,7 @@ def test_image_unstructured_index(self, mock_infer_modality, mock_load_image):
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -83,7 +83,7 @@ def test_image_structured_index(self, mock_infer_modality, mock_load_image):
media_field_types_mapping = {"field1": FieldType.ImagePointer}
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -106,7 +106,7 @@ def test_video_unstructured_index(self, mock_infer_modality, mock_download_and_c
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -130,7 +130,7 @@ def test_audio_structured_index(self, mock_infer_modality, mock_download_and_chu
media_field_types_mapping = {"field1": FieldType.AudioPointer}
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -148,7 +148,7 @@ def test_unsupported_modality(self, mock_infer_modality):
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -167,7 +167,7 @@ def test_image_load_error(self, mock_infer_modality, mock_load_image):
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -189,7 +189,7 @@ def test_video_processing_error(self, mock_infer_modality, mock_download_and_chu
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -217,7 +217,7 @@ def test_video_and_audio_unstructured_index(self, mock_infer_modality, mock_down
# Call the function
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -233,10 +233,10 @@ def test_video_and_audio_unstructured_index(self, mock_infer_modality, mock_down
# Verify the calls to download_and_chunk_media
mock_download_and_chunk.assert_any_call(
- self.mock_video_url, "cpu", None, Modality.VIDEO, self.mock_marqo_index.type, self.mock_marqo_index.model, None, None, None
+ self.mock_video_url, "cpu", {}, Modality.VIDEO, self.mock_marqo_index.type, self.mock_marqo_index.model, None, None, None
)
mock_download_and_chunk.assert_any_call(
- self.mock_audio_url, "cpu", None, Modality.AUDIO, self.mock_marqo_index.type, self.mock_marqo_index.model, None, None, None
+ self.mock_audio_url, "cpu", {}, Modality.AUDIO, self.mock_marqo_index.type, self.mock_marqo_index.model, None, None, None
)
@patch("marqo.tensor_search.add_docs.download_and_chunk_media")
@@ -261,7 +261,7 @@ def test_mismatched_media_fields(self, mock_infer_modality, mock_download_and_ch
]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -291,7 +291,7 @@ def test_invalid_media_fields(self, mock_infer_modality):
mock_infer_modality.side_effect = [Modality.TEXT, Modality.TEXT]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -321,7 +321,7 @@ def test_ffmpeg_error_handling(self, mock_infer_modality, mock_download_and_chun
mock_download_and_chunk.side_effect = ffmpeg.Error("FFmpeg processing error", stdout=b"", stderr=b"")
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -347,7 +347,7 @@ def test_valid_image_processing(self, mock_infer_modality, mock_load_image):
media_field_types_mapping = {"image_field": FieldType.ImagePointer}
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
media_field_types_mapping=media_field_types_mapping
@@ -365,7 +365,7 @@ def test_media_download_error(self, mock_infer_modality):
tensor_fields = ["field1"]
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
@@ -392,7 +392,7 @@ def test_audio_with_video_only_model(self, mock_infer_modality, mock_download_an
# Call the function
threaded_download_and_preprocess_content(
- docs, media_repo, tensor_fields, {}, device="cpu",
+ docs, media_repo, tensor_fields, media_download_headers={}, device="cpu",
marqo_index_type=self.mock_marqo_index.type,
marqo_index_model=self.mock_marqo_index.model,
)
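
All of the test updates above follow the same pattern: the bare positional {} becomes an explicit media_download_headers keyword. A hedged sketch of the resulting call shape; the fixtures (docs, media_repo, tensor_fields, mock_marqo_index) are assumed to be set up by the test class, and the import path is assumed from the @patch targets used in these tests:

from marqo.tensor_search.add_docs import threaded_download_and_preprocess_content

# Sketch only: mirrors the updated test calls above.
threaded_download_and_preprocess_content(
    docs, media_repo, tensor_fields,
    media_download_headers={"Authorization": "Bearer <token>"},  # was a bare positional {}
    device="cpu",
    marqo_index_type=mock_marqo_index.type,
    marqo_index_model=mock_marqo_index.model,
)
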
From 31048f070dcfadfff7b48db54a78848a15c29234 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 17:35:23 +1100
Subject: [PATCH 14/29] Replace all the image_download_headers with
media_download_headers
---
.../embedding_models/abstract_clip_model.py | 20 +-
.../embedding_models/image_download.py | 236 ------------------
.../embedding_models/open_clip_model.py | 4 +-
src/marqo/core/inference/image_download.py | 20 +-
src/marqo/s2_inference/clip_utils.py | 54 ++--
src/marqo/tensor_search/tensor_search.py | 4 +-
tests/s2_inference/test_image_downloading.py | 8 +-
.../test_add_documents_combined.py | 10 +-
tests/tensor_search/integ_tests/test_embed.py | 10 +-
...test_add_documents_use_existing_tensors.py | 2 +-
tests/tensor_search/test_api_utils.py | 18 +-
.../test_image_download_headers.py | 24 +-
tests/tensor_search/test_search.py | 2 +-
13 files changed, 88 insertions(+), 324 deletions(-)
delete mode 100644 src/marqo/core/inference/embedding_models/image_download.py
diff --git a/src/marqo/core/inference/embedding_models/abstract_clip_model.py b/src/marqo/core/inference/embedding_models/abstract_clip_model.py
index 43eb3a849..42b8c2d8c 100644
--- a/src/marqo/core/inference/embedding_models/abstract_clip_model.py
+++ b/src/marqo/core/inference/embedding_models/abstract_clip_model.py
@@ -7,7 +7,7 @@
from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.core.inference.embedding_models.abstract_embedding_model import AbstractEmbeddingModel
-from marqo.core.inference.embedding_models.image_download import (_is_image, format_and_load_CLIP_images,
+from marqo.core.inference.image_download import (_is_image, format_and_load_CLIP_images,
format_and_load_CLIP_image)
from marqo.s2_inference.logger import get_logger
from marqo.s2_inference.types import *
@@ -50,7 +50,7 @@ def encode_text(self, inputs: Union[str, List[str]], normalize: bool = True) ->
pass
@abstractmethod
- def encode_image(self, inputs, normalize: bool = True, image_download_headers: dict = None) -> np.ndarray:
+ def encode_image(self, inputs, normalize: bool = True, media_download_headers: dict = None) -> np.ndarray:
pass
def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], normalize=True, **kwargs) -> np.ndarray:
@@ -68,8 +68,8 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], nor
if is_image:
logger.debug('image')
- image_download_headers = kwargs.get("media_download_headers", dict())
- return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
+ media_download_headers = kwargs.get("media_download_headers", dict())
+ return self.encode_image(inputs, normalize=normalize, media_download_headers=media_download_headers)
else:
logger.debug('text')
return self.encode_text(inputs, normalize=normalize)
@@ -85,27 +85,27 @@ def normalize(outputs):
return outputs.norm(dim=-1, keepdim=True)
def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
- image_download_headers: Optional[Dict] = None) -> Tensor:
+ media_download_headers: Optional[Dict] = None) -> Tensor:
"""Preprocess the input image to be ready for the model.
Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
- image_download_headers (Optional[Dict]): headers for the image download
+ media_download_headers (Optional[Dict]): headers for the image download
Return:
Tensor: the processed image tensor with shape (batch_size, channel, n_px, n_px)
"""
if self.model is None:
self.load()
- if image_download_headers is None:
- image_download_headers = dict()
+ if media_download_headers is None:
+ media_download_headers = dict()
# default to batch encoding
if isinstance(images, list):
image_input: List[Union[ImageType, Tensor]] \
- = format_and_load_CLIP_images(images, image_download_headers)
+ = format_and_load_CLIP_images(images, media_download_headers)
else:
- image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, image_download_headers)]
+ image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, media_download_headers)]
image_input_processed: Tensor = torch.stack([self.preprocess(_img).to(self.device) \
if not isinstance(_img, torch.Tensor) else _img \
diff --git a/src/marqo/core/inference/embedding_models/image_download.py b/src/marqo/core/inference/embedding_models/image_download.py
deleted file mode 100644
index 65c158e20..000000000
--- a/src/marqo/core/inference/embedding_models/image_download.py
+++ /dev/null
@@ -1,236 +0,0 @@
-import os
-from io import BytesIO
-
-import certifi
-import numpy as np
-import pycurl
-import requests
-import torch
-import validators
-from PIL import Image, UnidentifiedImageError
-from requests.utils import requote_uri
-
-from marqo import marqo_docs
-from marqo.api.exceptions import InternalError
-from marqo.s2_inference.errors import ImageDownloadError
-from marqo.s2_inference.types import *
-from marqo.tensor_search.telemetry import RequestMetrics
-
-# TODO Merge this with the one in clip_utils in the future refactoring
-
-DEFAULT_HEADERS = {'User-Agent': 'Marqobot/1.0'}
-
-
-def get_allowed_image_types():
- return {'.jpg', '.png', '.bmp', '.jpeg'}
-
-
-def _is_image(inputs: Union[str, List[Union[str, ImageType, ndarray]]]) -> bool:
- # some logic to determine if something is an image or not
- # assume the batch is the same type
- # maybe we use something like this https://github.com/ahupp/python-magic
-
- _allowed = get_allowed_image_types()
-
- # we assume the batch is this way if a list
- # otherwise apply over each element
- if isinstance(inputs, list):
-
- if len(inputs) == 0:
- raise UnidentifiedImageError("received empty list, expected at least one element.")
-
- thing = inputs[0]
- else:
- thing = inputs
-
- # if it is a string, determine if it is a local file or url
- if isinstance(thing, str):
- name, extension = os.path.splitext(thing.lower())
-
- # if it has the correct extension, asssume yes
- if extension in _allowed:
- return True
-
- # if it is a local file without extension, then raise an error
- if os.path.isfile(thing):
- # we could also read the first part of the file and infer
- raise UnidentifiedImageError(
- f"local file [{thing}] extension {extension} does not match allowed file types of {_allowed}")
- else:
- # if it is not a local file and does not have an extension
- # check if url
- if validators.url(thing):
- return True
- else:
- return False
-
- # if it is an array, then it is an image
- elif isinstance(thing, (ImageType, ndarray, Tensor)):
- return True
- else:
- raise UnidentifiedImageError(f"expected type Image or str for inputs but received type {type(thing)}")
-
-
-def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], image_download_headers: dict) -> List[
- ImageType]:
- """takes in a list of strings, arrays or urls and either loads and/or converts to PIL
- for the clip model
-
- Args:
- images (List[Union[str, np.ndarray, ImageType]]): list of file locations or arrays (can be mixed)
-
- Raises:
- TypeError: _description_
-
- Returns:
- List[ImageType]: list of PIL images
- """
- if not isinstance(images, list):
- raise TypeError(f"expected list but received {type(images)}")
-
- results = []
- for image in images:
- results.append(format_and_load_CLIP_image(image, image_download_headers))
-
- return results
-
-
-def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
- image_download_headers: dict) -> Union[ImageType, Tensor]:
- """standardizes the input to be a PIL image
-
- Args:
- image (Union[str, np.ndarray, ImageType, Tensor]): can be a local file, url, array or a tensor
-
- Raises:
- ValueError: _description_
- TypeError: _description_
-
- Returns:
- standardized the image:
- ImageType: PIL image if input is a string, an array or a PIL image
- Tensor: torch tensor if input is a torch tensor
- """
- # check for the input type
- if isinstance(image, str):
- img = load_image_from_path(image, image_download_headers)
- elif isinstance(image, np.ndarray):
- img = Image.fromarray(image.astype('uint8'), 'RGB')
- elif isinstance(image, torch.Tensor):
- img = image
- elif isinstance(image, ImageType):
- img = image
- else:
- raise UnidentifiedImageError(f"input of type {type(image)} "
- f"did not match allowed types of str, np.ndarray, ImageType, Tensor")
-
- return img
-
-
-def load_image_from_path(image_path: str, image_download_headers: dict, timeout_ms=3000,
- metrics_obj: Optional[RequestMetrics] = None) -> ImageType:
- """Loads an image into PIL from a string path that is either local or a url
-
- Args:
- image_path (str): Local or remote path to image.
- image_download_headers (dict): header for the image download
- timeout_ms (int): timeout (in milliseconds), for the whole request
- Raises:
- ValueError: If the local path is invalid, and is not a url
- UnidentifiedImageError: If the image is irretrievable or unprocessable.
-
- Returns:
- ImageType: In-memory PIL image.
- """
- if os.path.isfile(image_path):
- img = Image.open(image_path)
- elif validators.url(image_path):
- if metrics_obj is not None:
- metrics_obj.start(f"image_download.{image_path}")
- try:
- img_io: BytesIO = download_image_from_url(image_path, image_download_headers, timeout_ms)
- img = Image.open(img_io)
- except ImageDownloadError as e:
- raise UnidentifiedImageError(str(e)) from e
- finally:
- if metrics_obj is not None:
- metrics_obj.stop(f"image_download.{image_path}")
- else:
- raise UnidentifiedImageError(f"Input str of {image_path} is not a local file or a valid url. "
- f"If you are using Marqo Cloud, please note that images can only be downloaded "
- f"from a URL and local files are not supported. "
- f"If you are running Marqo in a Docker container, you will need to use a Docker "
- f"volume so that your container can access host files. "
- f"For more information, please refer to: "
- f"{marqo_docs.indexing_images()}")
-
- return img
-
-
-def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
- """Download an image from a URL and return a PIL image using pycurl.
-
- Args:
- image_path (str): URL to the image.
- image_download_headers (dict): Headers for the image download.
- timeout_ms (int): Timeout in milliseconds, for the whole request.
-
- Returns:
- buffer (BytesIO): The image as a BytesIO object.
-
- Raises:
- ImageDownloadError: If the image download fails.
- """
-
- if not isinstance(timeout_ms, int):
- raise InternalError(f"timeout must be an integer but received {timeout_ms} of type {type(timeout_ms)}")
-
- try:
- encoded_url = encode_url(image_path)
- except UnicodeEncodeError as e:
- raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. "
- f"The url could not be encoded properly. Original error: {e}")
- buffer = BytesIO()
- c = pycurl.Curl()
- c.setopt(pycurl.CAINFO, certifi.where())
- c.setopt(pycurl.URL, encoded_url)
- c.setopt(pycurl.WRITEDATA, buffer)
- c.setopt(pycurl.TIMEOUT_MS, timeout_ms)
- c.setopt(pycurl.FOLLOWLOCATION, 1)
-
- headers = DEFAULT_HEADERS.copy()
- headers.update(image_download_headers)
- c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
-
- try:
- c.perform()
- if c.getinfo(pycurl.RESPONSE_CODE) != 200:
- raise ImageDownloadError(f"image url `{image_path}` returned {c.getinfo(pycurl.RESPONSE_CODE)}")
- except pycurl.error as e:
- raise ImageDownloadError(f"Marqo encountered an error when downloading the image url {image_path}. "
- f"The original error is: {e}")
- finally:
- c.close()
- buffer.seek(0)
- return buffer
-
-
-def encode_url(url: str) -> str:
- """
- Encode a URL to a valid format with only ASCII characters and reserved characters using percent-encoding.
-
- In version 2.8, we replaced the requests library with pycurl for image downloads. Consequently, we need to implement
- the URL encoding function ourselves. This function replicates the encoding behavior of the
- 'requests.utils.requote_uri' function from the requests library.
-
- Args:
- url (str): The URL to encode.
-
- Returns:
- str: The encoded URL.
-
- Raises:
- UnicodeEncodeError: If the URL cannot be encoded properly.
-
- """
- return requests.utils.requote_uri(url)
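
The module above is removed because a copy already lives at src/marqo/core/inference/image_download.py and is updated in the next file diff. A small usage sketch against that surviving copy, assuming the import path shown there; the URL and token are placeholders:

from PIL import Image
from marqo.core.inference.image_download import download_image_from_url

# Custom headers are merged on top of the default {'User-Agent': 'Marqobot/1.0'}.
buffer = download_image_from_url(
    "https://example.com/photo.jpg",                              # placeholder URL
    media_download_headers={"Authorization": "Bearer <token>"},
    timeout_ms=3000,
)
img = Image.open(buffer)  # the helper returns a BytesIO buffer
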
diff --git a/src/marqo/core/inference/embedding_models/open_clip_model.py b/src/marqo/core/inference/embedding_models/open_clip_model.py
index e79cb9feb..fdc050316 100644
--- a/src/marqo/core/inference/embedding_models/open_clip_model.py
+++ b/src/marqo/core/inference/embedding_models/open_clip_model.py
@@ -247,10 +247,10 @@ def _download_from_repo(self):
return model_file_path
def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]],
- image_download_headers: Optional[Dict] = None,
+ media_download_headers: Optional[Dict] = None,
normalize=True) -> FloatTensor:
- self.image_input_processed: Tensor = self._preprocess_images(images, image_download_headers)
+ self.image_input_processed: Tensor = self._preprocess_images(images, media_download_headers)
with torch.no_grad():
if self.device.startswith("cuda"):
diff --git a/src/marqo/core/inference/image_download.py b/src/marqo/core/inference/image_download.py
index 65c158e20..9cebb5948 100644
--- a/src/marqo/core/inference/image_download.py
+++ b/src/marqo/core/inference/image_download.py
@@ -71,7 +71,7 @@ def _is_image(inputs: Union[str, List[Union[str, ImageType, ndarray]]]) -> bool:
raise UnidentifiedImageError(f"expected type Image or str for inputs but received type {type(thing)}")
-def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], image_download_headers: dict) -> List[
+def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], media_download_headers: dict) -> List[
ImageType]:
"""takes in a list of strings, arrays or urls and either loads and/or converts to PIL
for the clip model
@@ -90,13 +90,13 @@ def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], im
results = []
for image in images:
- results.append(format_and_load_CLIP_image(image, image_download_headers))
+ results.append(format_and_load_CLIP_image(image, media_download_headers))
return results
def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
- image_download_headers: dict) -> Union[ImageType, Tensor]:
+ media_download_headers: dict) -> Union[ImageType, Tensor]:
"""standardizes the input to be a PIL image
Args:
@@ -113,7 +113,7 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
"""
# check for the input type
if isinstance(image, str):
- img = load_image_from_path(image, image_download_headers)
+ img = load_image_from_path(image, media_download_headers)
elif isinstance(image, np.ndarray):
img = Image.fromarray(image.astype('uint8'), 'RGB')
elif isinstance(image, torch.Tensor):
@@ -127,13 +127,13 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
return img
-def load_image_from_path(image_path: str, image_download_headers: dict, timeout_ms=3000,
+def load_image_from_path(image_path: str, media_download_headers: dict, timeout_ms=3000,
metrics_obj: Optional[RequestMetrics] = None) -> ImageType:
"""Loads an image into PIL from a string path that is either local or a url
Args:
image_path (str): Local or remote path to image.
- image_download_headers (dict): header for the image download
+ media_download_headers (dict): header for the image download
timeout_ms (int): timeout (in milliseconds), for the whole request
Raises:
ValueError: If the local path is invalid, and is not a url
@@ -148,7 +148,7 @@ def load_image_from_path(image_path: str, image_download_headers: dict, timeout_
if metrics_obj is not None:
metrics_obj.start(f"image_download.{image_path}")
try:
- img_io: BytesIO = download_image_from_url(image_path, image_download_headers, timeout_ms)
+ img_io: BytesIO = download_image_from_url(image_path, media_download_headers, timeout_ms)
img = Image.open(img_io)
except ImageDownloadError as e:
raise UnidentifiedImageError(str(e)) from e
@@ -167,12 +167,12 @@ def load_image_from_path(image_path: str, image_download_headers: dict, timeout_
return img
-def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
+def download_image_from_url(image_path: str, media_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
"""Download an image from a URL and return a PIL image using pycurl.
Args:
image_path (str): URL to the image.
- image_download_headers (dict): Headers for the image download.
+ media_download_headers (dict): Headers for the image download.
timeout_ms (int): Timeout in milliseconds, for the whole request.
Returns:
@@ -199,7 +199,7 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
c.setopt(pycurl.FOLLOWLOCATION, 1)
headers = DEFAULT_HEADERS.copy()
- headers.update(image_download_headers)
+ headers.update(media_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
try:
diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py
index ff787dd08..d1fd5c684 100644
--- a/src/marqo/s2_inference/clip_utils.py
+++ b/src/marqo/s2_inference/clip_utils.py
@@ -67,7 +67,7 @@ def _get_transform(n_px: int, image_mean: List[float] = None, image_std: List[fl
])
-def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], image_download_headers: dict) -> List[
+def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], media_download_headers: dict) -> List[
ImageType]:
"""takes in a list of strings, arrays or urls and either loads and/or converts to PIL
for the clip model
@@ -86,18 +86,18 @@ def format_and_load_CLIP_images(images: List[Union[str, ndarray, ImageType]], im
results = []
for image in images:
- results.append(format_and_load_CLIP_image(image, image_download_headers))
+ results.append(format_and_load_CLIP_image(image, media_download_headers))
return results
-def load_image_from_path(image_path: str, image_download_headers: dict, timeout_ms=3000,
+def load_image_from_path(image_path: str, media_download_headers: dict, timeout_ms=3000,
metrics_obj: Optional[RequestMetrics] = None) -> ImageType:
"""Loads an image into PIL from a string path that is either local or a url
Args:
image_path (str): Local or remote path to image.
- image_download_headers (dict): header for the image download
+ media_download_headers (dict): header for the image download
timeout_ms (int): timeout (in milliseconds), for the whole request
Raises:
ValueError: If the local path is invalid, and is not a url
@@ -112,7 +112,7 @@ def load_image_from_path(image_path: str, image_download_headers: dict, timeout_
if metrics_obj is not None:
metrics_obj.start(f"image_download.{image_path}")
try:
- img_io: BytesIO = download_image_from_url(image_path, image_download_headers, timeout_ms)
+ img_io: BytesIO = download_image_from_url(image_path, media_download_headers, timeout_ms)
img = Image.open(img_io)
except ImageDownloadError as e:
raise UnidentifiedImageError(str(e)) from e
@@ -145,12 +145,12 @@ def validate_url(url: str) -> bool:
-def download_image_from_url(image_path: str, image_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
+def download_image_from_url(image_path: str, media_download_headers: dict, timeout_ms: int = 3000) -> BytesIO:
"""Download an image from a URL and return a PIL image using pycurl.
Args:
image_path (str): URL to the image.
- image_download_headers (dict): Headers for the image download.
+ media_download_headers (dict): Headers for the image download.
timeout_ms (int): Timeout in milliseconds, for the whole request.
Returns:
@@ -177,9 +177,9 @@ def download_image_from_url(image_path: str, image_download_headers: dict, timeo
c.setopt(pycurl.FOLLOWLOCATION, 1)
headers = DEFAULT_HEADERS.copy()
- if image_download_headers is None:
- image_download_headers = dict()
- headers.update(image_download_headers)
+ if media_download_headers is None:
+ media_download_headers = dict()
+ headers.update(media_download_headers)
c.setopt(pycurl.HTTPHEADER, [f"{k}: {v}" for k, v in headers.items()])
try:
@@ -217,7 +217,7 @@ def encode_url(url: str) -> str:
def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
- image_download_headers: dict) -> Union[ImageType, Tensor]:
+ media_download_headers: dict) -> Union[ImageType, Tensor]:
"""standardizes the input to be a PIL image
Args:
@@ -234,7 +234,7 @@ def format_and_load_CLIP_image(image: Union[str, ndarray, ImageType, Tensor],
"""
# check for the input type
if isinstance(image, str):
- img = load_image_from_path(image, image_download_headers)
+ img = load_image_from_path(image, media_download_headers)
elif isinstance(image, np.ndarray):
img = Image.fromarray(image.astype('uint8'), 'RGB')
elif isinstance(image, torch.Tensor):
@@ -420,27 +420,27 @@ def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatT
return self._convert_output(outputs)
def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
- image_download_headers: Optional[Dict] = None) -> Tensor:
+ media_download_headers: Optional[Dict] = None) -> Tensor:
"""Preprocess the input image to be ready for the model.
Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
- image_download_headers (Optional[Dict]): headers for the image download
+ media_download_headers (Optional[Dict]): headers for the image download
Return:
Tensor: the processed image tensor with shape (batch_size, channel, n_px, n_px)
"""
if self.model is None:
self.load()
- if image_download_headers is None:
- image_download_headers = dict()
+ if media_download_headers is None:
+ media_download_headers = dict()
# default to batch encoding
if isinstance(images, list):
image_input: List[Union[ImageType, Tensor]] \
- = format_and_load_CLIP_images(images, image_download_headers)
+ = format_and_load_CLIP_images(images, media_download_headers)
else:
- image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, image_download_headers)]
+ image_input: List[Union[ImageType, Tensor]] = [format_and_load_CLIP_image(images, media_download_headers)]
image_input_processed: Tensor = torch.stack([self.preprocess(_img).to(self.device) \
if not isinstance(_img, torch.Tensor) else _img \
@@ -448,18 +448,18 @@ def _preprocess_images(self, images: Union[str, ImageType, List[Union[str, Image
return image_input_processed
def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor],
- normalize=True, image_download_headers: Optional[Dict] = None) -> FloatTensor:
+ normalize=True, media_download_headers: Optional[Dict] = None) -> FloatTensor:
"""Encode the input image to a tensor representation.
Args:
images (Union[str, ImageType, List[Union[str, ImageType, Tensor]], Tensor]): input image,
can be a str(url), a PIL image, or a tensor, or a list of them
normalize (bool): whether to normalize the output tensor
- image_download_headers (Optional[Dict]): headers for the image download
+ media_download_headers (Optional[Dict]): headers for the image download
Return:
FloatTensor: the encoded image tensor with shape (batch_size, embedding_dim)
"""
- self.image_input_processed: Tensor = self._preprocess_images(images, image_download_headers)
+ self.image_input_processed: Tensor = self._preprocess_images(images, media_download_headers)
with torch.no_grad():
outputs = self.model.encode_image(self.image_input_processed)
@@ -487,8 +487,8 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]],
if is_image:
logger.debug('image')
- image_download_headers = kwargs.get("media_download_headers", dict())
- return self.encode_image(inputs, normalize=normalize, image_download_headers=image_download_headers)
+ media_download_headers = kwargs.get("media_download_headers", dict())
+ return self.encode_image(inputs, normalize=normalize, media_download_headers=media_download_headers)
else:
logger.debug('text')
return self.encode_text(inputs, normalize=normalize)
@@ -573,16 +573,16 @@ def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatT
return self._convert_output(outputs)
def encode_image(self, images: Union[str, ImageType, List[Union[str, ImageType]]],
- normalize=True, image_download_headers: Optional[dict] = None) -> FloatTensor:
+ normalize=True, media_download_headers: Optional[dict] = None) -> FloatTensor:
if self.visual_model is None:
self.load()
- if image_download_headers is None:
- image_download_headers = dict()
+ if media_download_headers is None:
+ media_download_headers = dict()
# default to batch encoding
if isinstance(images, list):
- image_input = format_and_load_CLIP_images(images, image_download_headers)
+ image_input = format_and_load_CLIP_images(images, media_download_headers)
else:
image_input = [format_and_load_CLIP_image(images, {})]
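
After these renames, the CLIP-style models receive headers through a single media_download_headers keyword, whether encode() or encode_image() is called directly. A hedged sketch, assuming clip_model is an already-loaded CLIP wrapper exposing the methods edited above:

headers = {"Authorization": "Bearer <token>"}          # placeholder credentials

# encode() forwards the kwarg to encode_image() for image inputs ...
vec = clip_model.encode("https://example.com/photo.jpg",
                        normalize=True,
                        media_download_headers=headers)

# ... which is equivalent to calling encode_image() directly.
vec = clip_model.encode_image(["https://example.com/photo.jpg"],
                              normalize=True,
                              media_download_headers=headers)
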
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index 9e7381379..f040f899b 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -186,7 +186,7 @@ def _add_documents_unstructured(config: Config, add_docs_params: AddDocsParams,
docs=docs,
thread_count=media_download_thread_count,
tensor_fields=tensor_fields_and_multimodal_subfields,
- image_download_headers=add_docs_params.image_download_headers,
+ media_download_headers=add_docs_params.media_download_headers,
model_name=marqo_index.model.name,
normalize_embeddings=marqo_index.normalize_embeddings,
media_field_types_mapping=None,
@@ -709,7 +709,7 @@ def _add_documents_structured(config: Config, add_docs_params: AddDocsParams, ma
docs=docs,
thread_count=media_download_thread_count,
tensor_fields=media_fields,
- image_download_headers=add_docs_params.image_download_headers,
+ media_download_headers=add_docs_params.media_download_headers,
# add non image download headers in the future
model_name=marqo_index.model.name,
normalize_embeddings=marqo_index.normalize_embeddings,
diff --git a/tests/s2_inference/test_image_downloading.py b/tests/s2_inference/test_image_downloading.py
index 89f88200f..29a214024 100644
--- a/tests/s2_inference/test_image_downloading.py
+++ b/tests/s2_inference/test_image_downloading.py
@@ -53,12 +53,12 @@ def test_download_image_from_url_handleDifferentUrlsCorrectly(self):
for url, expected, msg in self.test_cases:
with self.subTest(url=url, expected=expected, msg=msg):
with self.assertRaises(ImageDownloadError) as cm:
- download_image_from_url(image_path=url + ".jpg", image_download_headers={})
+ download_image_from_url(image_path=url + ".jpg", media_download_headers={})
def test_download_image_from_url_handlesUrlRequiringUserAgentHeader(self):
url_requiring_user_agent_header = "https://docs.marqo.ai/2.0.0/Examples/marqo.jpg"
try:
- download_image_from_url(image_path=url_requiring_user_agent_header, image_download_headers={})
+ download_image_from_url(image_path=url_requiring_user_agent_header, media_download_headers={})
except Exception as e:
self.fail(f"Exception was raised when downloading {url_requiring_user_agent_header}: {e}")
@@ -77,7 +77,7 @@ def test_download_image_from_url_mergesDefaultHeadersWithCustomHeaders(self, moc
for (headers, expected_headers, msg) in test_cases:
with self.subTest(headers=headers, expected_headers=expected_headers, msg=msg):
- download_image_from_url('http://example.com/image.jpg', image_download_headers=headers)
+ download_image_from_url('http://example.com/image.jpg', media_download_headers=headers)
mock_curl_instance.setopt.assert_called_with(pycurl.HTTPHEADER, expected_headers)
def test_download_image_from_url_handlesRedirection(self):
@@ -88,5 +88,5 @@ def test_download_image_from_url_handlesRedirection(self):
])
with MockHttpServer(app).run_in_thread() as base_url:
- result = download_image_from_url(f'{base_url}/missing_image.jpg', image_download_headers={})
+ result = download_image_from_url(f'{base_url}/missing_image.jpg', media_download_headers={})
self.assertEqual(result.getvalue(), image_content)
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index 417dc3028..4a9238a66 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -578,7 +578,7 @@ def test_imageDownloadWithoutPreprocessor(self):
allocated_docs=[test_doc],
media_repo=media_repo,
tensor_fields=['field_1', 'field_2'],
- image_download_headers={},
+ media_download_headers={},
marqo_index_type=IndexType.Unstructured,
marqo_index_model=Model(name="test", properties={}),
)
@@ -598,7 +598,7 @@ def test_imageDownloadWithPreprocessor(self):
allocated_docs=[test_doc],
media_repo=media_repo,
tensor_fields=['field_1', 'field_2'],
- image_download_headers={},
+ media_download_headers={},
preprocessors={'image': lambda x: torch.randn(3, 224, 224)},
device='cpu',
marqo_index_type=IndexType.Unstructured,
@@ -620,7 +620,7 @@ def run():
{"Title": "frog", "Desc": "blah"}, {"Title": "Dog", "Loc": "https://google.com/my_dog.png"}],
media_repo=media_repo,
tensor_fields=['Title', 'Desc', 'Loc'],
- image_download_headers={},
+ media_download_headers={},
marqo_index_type=IndexType.Unstructured,
marqo_index_model=Model(name="test", properties={}),
)
@@ -709,7 +709,7 @@ def test_threaded_download_images_non_tensor_field(self):
allocated_docs=docs,
media_repo=media_repo,
tensor_fields=['field_1', 'field_2'],
- image_download_headers={},
+ media_download_headers={},
marqo_index_type=IndexType.Unstructured,
marqo_index_model=Model(name="test", properties={}),
)
@@ -761,7 +761,7 @@ def test_download_images_non_tensor_field(self):
docs=docs,
thread_count=20,
tensor_fields=['field_1', 'field_2'],
- image_download_headers={},
+ media_download_headers={},
model_name="ViT-B/32",
normalize_embeddings=True,
model_properties=model_properties,
diff --git a/tests/tensor_search/integ_tests/test_embed.py b/tests/tensor_search/integ_tests/test_embed.py
index a971ced5d..77ba44f98 100644
--- a/tests/tensor_search/integ_tests/test_embed.py
+++ b/tests/tensor_search/integ_tests/test_embed.py
@@ -523,9 +523,9 @@ def run():
self.assertEqual(embed_res["content"], [image_url])
self.assertTrue(np.allclose(embed_res["embeddings"][0], search_query_embedding))
- def test_embed_with_image_download_headers_and_model_auth(self):
+ def test_embed_with_media_download_headers_and_model_auth(self):
"""
- Ensure that vectorise is called with the correct image_download_headers and model_auth
+ Ensure that vectorise is called with the correct media_download_headers and model_auth
when using the embed endpoint.
"""
for index in [self.unstructured_default_image_index, self.structured_default_image_index]:
@@ -537,7 +537,7 @@ def pass_through_vectorise(*arg, **kwargs):
via mock
Set image download headers and model auth to None so there's no error out.
"""
- kwargs["image_download_headers"] = None
+ kwargs["media_download_headers"] = None
kwargs["model_auth"] = None
return vectorise(*arg, **kwargs)
@@ -549,7 +549,7 @@ def run():
marqo_config=self.config, index_name=index.name,
embedding_request=EmbedRequest(
content=[image_url],
- image_download_headers={"Authorization": "my secret key"},
+ media_download_headers={"Authorization": "my secret key"},
modelAuth=ModelAuth(s3=S3Auth(
aws_access_key_id='12345',
aws_secret_access_key='this-is-a-secret'))
@@ -564,7 +564,7 @@ def run():
self.assertEqual(len(call_args), 1)
vectorise_kwargs = call_args[0].kwargs
- self.assertEqual(vectorise_kwargs["image_download_headers"], {"Authorization": "my secret key"})
+ self.assertEqual(vectorise_kwargs["media_download_headers"], {"Authorization": "my secret key"})
self.assertEqual(vectorise_kwargs["model_auth"], ModelAuth(s3=S3Auth(
aws_access_key_id='12345',
aws_secret_access_key='this-is-a-secret')))
diff --git a/tests/tensor_search/test_add_documents_use_existing_tensors.py b/tests/tensor_search/test_add_documents_use_existing_tensors.py
index cd9ea8e88..b1febcfc3 100644
--- a/tests/tensor_search/test_add_documents_use_existing_tensors.py
+++ b/tests/tensor_search/test_add_documents_use_existing_tensors.py
@@ -829,7 +829,7 @@ def run():
vectorised_content = [call_kwargs['content'] for call_args, call_kwargs
in mock_vectorise.call_args_list]
- artefact_pil_image = load_image_from_path(artefact_hippo_img, image_download_headers={})
+ artefact_pil_image = load_image_from_path(artefact_hippo_img, media_download_headers={})
expected_to_be_vectorised = [
["this is the updated 1st sentence.", "This is my second"],
["this is a brand new sentence.", "Yes it is"],
diff --git a/tests/tensor_search/test_api_utils.py b/tests/tensor_search/test_api_utils.py
index 437d81654..acb040651 100644
--- a/tests/tensor_search/test_api_utils.py
+++ b/tests/tensor_search/test_api_utils.py
@@ -98,13 +98,13 @@ def test_add_docs_params_orchestrator(self):
# Query parameters should be parsed as default values
non_tensor_fields = []
use_existing_tensors = False
- image_download_headers = dict()
+ media_download_headers = dict()
model_auth = None
mappings = dict()
# Call the function with the arguments
result = add_docs_params_orchestrator(index_name, body, device, non_tensor_fields, mappings,
- model_auth, image_download_headers, use_existing_tensors)
+ model_auth, media_download_headers, use_existing_tensors)
# Assert that the result is as expected
assert isinstance(result, AddDocsParams)
@@ -114,7 +114,7 @@ def test_add_docs_params_orchestrator(self):
assert result.non_tensor_fields == ["field1"]
assert result.use_existing_tensors == True
assert result.docs == [{"test": "doc"}]
- assert result.image_download_headers == {"header1": "value1"}
+ assert result.media_download_headers == {"header1": "value1"}
def test_add_docs_params_orchestrator_deprecated_query_parameters(self):
# Set up the arguments for the function
@@ -126,14 +126,14 @@ def test_add_docs_params_orchestrator_deprecated_query_parameters(self):
device = "test-device"
non_tensor_fields = ["field1"]
use_existing_tensors = True
- image_download_headers = {"header1": "value1"}
+ media_download_headers = {"header1": "value1"}
model_auth = model_auth
mappings = {"map1": "value1"}
auto_refresh = True
# Call the function with the arguments
result = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings,
- model_auth, image_download_headers, use_existing_tensors)
+ model_auth, media_download_headers, use_existing_tensors)
# Assert that the result is as expected
assert isinstance(result, AddDocsParams)
@@ -143,7 +143,7 @@ def test_add_docs_params_orchestrator_deprecated_query_parameters(self):
assert result.non_tensor_fields == ["field1"]
assert result.use_existing_tensors == True
assert result.docs == [{"test": "doc"}]
- assert result.image_download_headers == {"header1": "value1"}
+ assert result.media_download_headers == {"header1": "value1"}
def test_add_docs_params_orchestrator_error(self):
# Test the case where the function should raise an error due to invalid input
@@ -155,7 +155,7 @@ def test_add_docs_params_orchestrator_error(self):
device = "test-device"
non_tensor_fields = ["field1"]
use_existing_tensors = True
- image_download_headers = {"header1": "value1"}
+ media_download_headers = {"header1": "value1"}
model_auth = model_auth
mappings = {"map1": "value1"}
auto_refresh = True
@@ -163,7 +163,7 @@ def test_add_docs_params_orchestrator_error(self):
# Use pytest.raises to check for the error
try:
_ = add_docs_params_orchestrator(index_name, body, device, auto_refresh, non_tensor_fields, mappings,
- model_auth, image_download_headers, use_existing_tensors)
+ model_auth, media_download_headers, use_existing_tensors)
except InternalError as e:
self.assertIn("Unexpected request body type", str(e))
@@ -181,7 +181,7 @@ def test_add_docs_params_orchestrator_deprecated_query_parameters_error(self):
mappings={"map1": "value1"})
params = {"non_tensor_fields": ["what"], "use_existing_tensors": True,
- "image_download_headers": {"header2": "value2"}, "model_auth": model_auth,
+ "media_download_headers": {"header2": "value2"}, "model_auth": model_auth,
"mappings": {"map2": "value2"}}
for param, value in params.items():
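
A sketch of the orchestrator call after the rename, mirroring the positional order exercised in the first test above; all variables are the placeholders defined in that test:

# Positional order as in test_add_docs_params_orchestrator above.
result = add_docs_params_orchestrator(index_name, body, device,
                                      non_tensor_fields, mappings, model_auth,
                                      media_download_headers,      # was image_download_headers
                                      use_existing_tensors)
assert result.media_download_headers == media_download_headers
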
diff --git a/tests/tensor_search/test_image_download_headers.py b/tests/tensor_search/test_image_download_headers.py
index ea692be9e..04c0ef0a7 100644
--- a/tests/tensor_search/test_image_download_headers.py
+++ b/tests/tensor_search/test_image_download_headers.py
@@ -62,11 +62,11 @@ def test_img_download_search(self):
tensor_search.create_vector_index(
config=self.config, index_name=self.index_name_1, index_settings=self.image_index_settings()
)
- image_download_headers = {"Authorization": "some secret key blah"}
+ media_download_headers = {"Authorization": "some secret key blah"}
self.add_documents(config=self.config, add_docs_params=AddDocsParams(
index_name=self.index_name_1, docs=[
{"_id": "1", "image": self.real_img_url}],
- auto_refresh=True, image_download_headers=image_download_headers, device="cpu"))
+ auto_refresh=True, media_download_headers=media_download_headers, device="cpu"))
def pass_through_requests_get(url, *args, **kwargs):
return requests_get(url, *args, **kwargs)
@@ -80,11 +80,11 @@ def pass_through_requests_get(url, *args, **kwargs):
# Perform a vector search
search_res = tensor_search._vector_text_search(
config=self.config, index_name=self.index_name_1,
- result_count=1, query=self.real_img_url, image_download_headers=image_download_headers, device="cpu"
+ result_count=1, query=self.real_img_url, media_download_headers=media_download_headers, device="cpu"
)
# Check if the image URL was called at least once with the correct headers
image_url_called = any(
- call_args[0] == self.real_img_url and call_kwargs.get('headers', None) == image_download_headers
+ call_args[0] == self.real_img_url and call_kwargs.get('headers', None) == media_download_headers
for call_args, call_kwargs in mock_get.call_args_list
)
assert image_url_called, "Image URL not called with the correct headers"
@@ -102,18 +102,18 @@ def pass_through_load_image_from_path(*arg, **kwargs):
@unittest.mock.patch("marqo.s2_inference.clip_utils.load_image_from_path", mock_load_image_from_path)
def run():
- image_download_headers = {"Authorization": "some secret key blah"}
+ media_download_headers = {"Authorization": "some secret key blah"}
# Add a document with an image URL
self.add_documents(config=self.config, add_docs_params=AddDocsParams(
index_name=self.index_name_1, docs=[
{"_id": "1", "image": self.real_img_url}
- ], auto_refresh=True, image_download_headers=image_download_headers, device="cpu"
+ ], auto_refresh=True, media_download_headers=media_download_headers, device="cpu"
))
# Check if load_image_from_path was called with the correct headers
assert len(mock_load_image_from_path.call_args_list) == 1
call_args, call_kwargs = mock_load_image_from_path.call_args_list[0]
- assert image_download_headers in call_args
+ assert media_download_headers in call_args
return True
assert run() is True
@@ -123,14 +123,14 @@ def test_img_download_bulk_search(self):
tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,
index_settings=self.image_index_settings())
test_image_url = self.real_img_url
- image_download_headers = {"Authorization": "some secret key blah"}
+ media_download_headers = {"Authorization": "some secret key blah"}
def pass_through_load_image_from_path(*args, **kwargs):
return load_image_from_path(*args, **kwargs)
def pass_through_requests_get(url, *args, **kwargs):
if url == test_image_url:
- assert kwargs.get('headers', None) == image_download_headers
+ assert kwargs.get('headers', None) == media_download_headers
return requests_get(url, *args, **kwargs)
# Mock the load_image_from_path function
@@ -144,7 +144,7 @@ def pass_through_requests_get(url, *args, **kwargs):
"_id": "1",
"image": test_image_url,
}],
- auto_refresh=True, image_download_headers=image_download_headers, device="cpu"))
+ auto_refresh=True, media_download_headers=media_download_headers, device="cpu"))
# Set up the mock GET
mock_get = unittest.mock.MagicMock()
@@ -155,13 +155,13 @@ def pass_through_requests_get(url, *args, **kwargs):
bulk_search_query = BulkSearchQuery(queries=[{
"index": self.index_name_1,
"q": self.real_img_url,
- "image_download_headers": image_download_headers
+ "media_download_headers": media_download_headers
}])
resp = tensor_search.bulk_search(marqo_config=self.config, query=bulk_search_query)
# Check if the image URL was called at least once with the correct headers
image_url_called = any(
- call_args[0] == test_image_url and call_kwargs.get('headers', None) == image_download_headers
+ call_args[0] == test_image_url and call_kwargs.get('headers', None) == media_download_headers
for call_args, call_kwargs in mock_get.call_args_list
)
assert image_url_called, "Image URL not called with the correct headers"
diff --git a/tests/tensor_search/test_search.py b/tests/tensor_search/test_search.py
index c44848c12..0a0fbdc15 100644
--- a/tests/tensor_search/test_search.py
+++ b/tests/tensor_search/test_search.py
@@ -1136,7 +1136,7 @@ def run() -> typing.List[float]:
weighted_vectors = []
for q, weight in multi_query.items():
vec = vectorise(model_name="ViT-B/16", content=[q, ],
- image_download_headers=None, normalize_embeddings=True,
+ media_download_headers=None, normalize_embeddings=True,
device="cpu")[0]
weighted_vectors.append(np.asarray(vec) * weight)
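
The same rename reaches vectorise(); a hedged sketch of a direct call, following the keyword set used in the test_search.py hunk above (the import path is an assumption, and the content URL and token are placeholders):

from marqo.s2_inference.s2_inference import vectorise  # import path assumed

vec = vectorise(
    model_name="ViT-B/16",
    content=["https://example.com/photo.jpg"],
    media_download_headers={"Authorization": "Bearer <token>"},   # was image_download_headers
    normalize_embeddings=True,
    device="cpu",
)[0]
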
From 2e35bacca465983bf2ee4ef735f0cd35af9af3e4 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 19:24:13 +1100
Subject: [PATCH 15/29] Fix tests
---
.../unstructured_add_document_handler.py | 2 +-
src/marqo/s2_inference/multimodal_model_load.py | 1 -
tests/s2_inference/test_vectorise.py | 3 ++-
tests/tensor_search/integ_tests/test_add_documents_combined.py | 3 ++-
tests/tensor_search/integ_tests/test_embed.py | 2 +-
5 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py b/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
index c9f89ccd9..7915455aa 100644
--- a/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
+++ b/src/marqo/core/unstructured_vespa_index/unstructured_add_document_handler.py
@@ -65,7 +65,7 @@ def _validate_doc(self, doc):
def _handle_field(self, marqo_doc, field_name, field_content):
self._validate_field(field_name, field_content)
- text_field_type = self._infer_field_type(field_content)
+ text_field_type = self._infer_field_type(field_content, self.add_docs_params.media_download_headers)
content = self.tensor_fields_container.collect(marqo_doc[MARQO_DOC_ID], field_name,
field_content, text_field_type)
marqo_doc[field_name] = content
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index 2dc6da1bd..ad4bb2506 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -170,7 +170,6 @@ def infer_modality(content: Union[str, List[str], bytes], media_download_headers
return Modality.VIDEO
elif extension in ['mp3', 'wav', 'ogg']:
return Modality.AUDIO
-
if validate_url(encoded_url):
# Use context manager to handle content sample
try:
diff --git a/tests/s2_inference/test_vectorise.py b/tests/s2_inference/test_vectorise.py
index 6e51446b0..5ccd1bde4 100644
--- a/tests/s2_inference/test_vectorise.py
+++ b/tests/s2_inference/test_vectorise.py
@@ -240,7 +240,8 @@ def test_vectorise_single_content_item(self):
result = s2_inference.vectorise(model_name='mock_model', content=single_content,
model_properties=self.mock_model_props, device="cpu")
- self.mock_model.encode.assert_called_once_with(single_content, normalize=True, modality=Modality.TEXT)
+ self.mock_model.encode.assert_called_once_with(single_content, normalize=True, modality=Modality.TEXT,
+ media_download_headers=None)
self.assertIsInstance(result, list)
self.assertEqual(len(result), 1)
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index 4a9238a66..2729ad830 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -1128,7 +1128,8 @@ def test_add_private_images_proper_error_returned(self):
def test_add_private_images_success(self):
"""Test to ensure that private images can be downloaded with proper headers"""
- test_indexes = [self.structured_marqo_index_name, self.unstructured_marqo_index_name]
+ # test_indexes = [self.structured_marqo_index_name, self.unstructured_marqo_index_name]
+ test_indexes = [self.unstructured_marqo_index_name, ]
documents = [
{
"image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
diff --git a/tests/tensor_search/integ_tests/test_embed.py b/tests/tensor_search/integ_tests/test_embed.py
index 77ba44f98..1e393ad69 100644
--- a/tests/tensor_search/integ_tests/test_embed.py
+++ b/tests/tensor_search/integ_tests/test_embed.py
@@ -549,7 +549,7 @@ def run():
marqo_config=self.config, index_name=index.name,
embedding_request=EmbedRequest(
content=[image_url],
- media_download_headers={"Authorization": "my secret key"},
+ mediaDownloadHeaders={"Authorization": "my secret key"},
modelAuth=ModelAuth(s3=S3Auth(
aws_access_key_id='12345',
aws_secret_access_key='this-is-a-secret'))
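
This last hunk switches the test to the request-level field name: EmbedRequest is constructed with the camelCase mediaDownloadHeaders, while the snake_case media_download_headers is what vectorise ultimately receives. A hedged sketch of the request construction, with placeholder content and credentials:

# Field names follow the hunk above; EmbedRequest, ModelAuth and S3Auth are the
# classes already imported in that test module.
embedding_request = EmbedRequest(
    content=["https://example.com/photo.jpg"],
    mediaDownloadHeaders={"Authorization": "my secret key"},
    modelAuth=ModelAuth(s3=S3Auth(
        aws_access_key_id="12345",
        aws_secret_access_key="this-is-a-secret")),
)
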
From 2cc4622332be2da095f3f7d16aba38ffbf0f90b4 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Thu, 24 Oct 2024 22:12:43 +1100
Subject: [PATCH 16/29] Add language bind modality tests
---
.../core/vespa_index/add_documents_handler.py | 1 -
tests/marqo_test.py | 15 +++
.../test_add_documents_combined.py | 127 ++++++++++++++++--
.../integ_tests/test_search_combined.py | 102 +++++++++++++-
4 files changed, 230 insertions(+), 15 deletions(-)
diff --git a/src/marqo/core/vespa_index/add_documents_handler.py b/src/marqo/core/vespa_index/add_documents_handler.py
index b181d35b7..8133abd4d 100644
--- a/src/marqo/core/vespa_index/add_documents_handler.py
+++ b/src/marqo/core/vespa_index/add_documents_handler.py
@@ -421,4 +421,3 @@ def _field_type_chunker_map(self, media_repo):
FieldType.VideoPointer: AudioVideoChunker(media_repo=media_repo),
}
return chunkers
-
diff --git a/tests/marqo_test.py b/tests/marqo_test.py
index 8d66c2a86..25edff9d7 100644
--- a/tests/marqo_test.py
+++ b/tests/marqo_test.py
@@ -36,6 +36,21 @@ class TestImageUrls(str, Enum):
HIPPO_STATUE = 'https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_statue_small.png'
+class TestAudioUrls(str, Enum):
+ __test__ = False
+ AUDIO1 = "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/audios/1-100032-A-0.wav"
+ AUDIO2 = "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/audios/1-115545-C-48.wav"
+ AUDIO3 = "https://marqo-ecs-50-audio-test-dataset.s3.us-east-1.amazonaws.com/audios/1-119125-A-45.wav"
+
+
+class TestVideoUrls(str, Enum):
+ __test__ = False
+ VIDEO1 = "https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/--_S9IDQPLg_000135_000145.mp4"
+ VIDEO2 = "https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4"
+ VIDEO3 = "https://marqo-k400-video-test-dataset.s3.us-east-1.amazonaws.com/videos/--mI_-gaZLk_000018_000028.mp4"
+
+
+
class MarqoTestCase(unittest.TestCase):
indexes = []
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index 8a9c7d6da..e7cc77ff2 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -1,25 +1,18 @@
import os
import unittest.mock
+import unittest.mock
import uuid
from unittest import mock
from unittest.mock import patch
import PIL
-import numpy as np
-
import numpy as np
import pytest
-
-
-import PIL
import requests
import torch
-from more_itertools import flatten
-from numpy.ma.core import subtract
from torch import Tensor
-import unittest.mock
-
+from marqo.core.models.add_docs_params import AddDocsParams, BatchVectorisationMode
from marqo.core.models.marqo_get_documents_by_id_response import MarqoGetDocumentsByIdsResponse
from marqo.core.models.marqo_index import *
from marqo.core.models.marqo_index_request import FieldRequest
@@ -28,10 +21,7 @@
from marqo.tensor_search import add_docs
from marqo.tensor_search import streaming_media_processor
from marqo.tensor_search import tensor_search
-from marqo.core.models.add_docs_params import AddDocsParams, BatchVectorisationMode
-from tests.marqo_test import MarqoTestCase, TestImageUrls
-from marqo.s2_inference.multimodal_model_load import infer_modality
-from marqo.tensor_search import streaming_media_processor
+from tests.marqo_test import MarqoTestCase, TestImageUrls, TestAudioUrls, TestVideoUrls
class TestAddDocumentsCombined(MarqoTestCase):
@@ -1174,4 +1164,115 @@ def test_add_private_images_success(self):
mappings=mappings
)
)
+ self.assertFalse(res.errors)
+
+
+
+
+@pytest.mark.largemodel
+class TestLanguageBindModelAddDocumentCombined(MarqoTestCase):
+ """A class to test the add_documents with the LanguageBind model."""
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ super().setUpClass()
+
+ structured_language_bind_index = cls.structured_marqo_index_request(
+ name="structured_image_index" + str(uuid.uuid4()).replace('-', ''),
+ fields=[
+ FieldRequest(name="text_field_1", type=FieldType.Text,
+ features=[FieldFeature.Filter, FieldFeature.LexicalSearch]),
+ FieldRequest(name="image_field_1", type=FieldType.ImagePointer),
+ FieldRequest(name="audio_field_1", type=FieldType.AudioPointer),
+ FieldRequest(name="video_field_1", type=FieldType.VideoPointer),
+ FieldRequest(
+ name="multimodal_field",
+ type=FieldType.MultimodalCombination,
+ dependent_fields={
+ "image_field_1": 1.0,
+ "text_field_1": 1.0,
+ "audio_field_1": 1.0,
+ "video_field_1": 1.0,
+ }
+ )
+ ],
+ model=Model(name="LanguageBind/Video_V1.5_FT_Audio_FT_Image"),
+ tensor_fields=["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"],
+ )
+
+ unstructured_language_bind_index = cls.unstructured_marqo_index_request(
+ name="unstructured_image_index" + str(uuid.uuid4()).replace('-', ''),
+ model=Model(name="LanguageBind/Video_V1.5_FT_Audio_FT_Image"),
+ treat_urls_and_pointers_as_images=True,
+ treat_urls_and_pointers_as_media=True
+ )
+
+ cls.indexes = cls.create_indexes([structured_language_bind_index, unstructured_language_bind_index])
+
+ cls.structured_language_bind_index_name = structured_language_bind_index.name
+ cls.unstructured_language_bind_index_name = unstructured_language_bind_index.name
+
+ s2_inference.clear_loaded_models()
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ super().tearDownClass()
+ s2_inference.clear_loaded_models()
+
+ def test_language_bind_model_can_add_all_media_modalities(self):
+ """Test to ensure that the LanguageBind model can add all media types to the index"""
+ documents = [
+ {
+ "text_field_1": "This is a test text",
+ "image_field_1": TestImageUrls.IMAGE1.value,
+ "audio_field_1": TestAudioUrls.AUDIO1.value,
+ "video_field_1": TestVideoUrls.VIDEO1.value,
+ "_id": "1"
+ }
+ ]
+ for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
+ tensor_fields = ["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"] \
+ if index_name == self.unstructured_language_bind_index_name else None
+ with self.subTest(index_name):
+ res = tensor_search.add_documents(
+ self.config,
+ add_docs_params=AddDocsParams(
+ docs=documents,
+ index_name=index_name,
+ tensor_fields=tensor_fields
+ )
+ )
+ self.assertFalse(res.errors)
+
+ def test_language_bind_model_can_add_all_private_media_modalities(self):
+ documents = [
+ { # With extensions
+ "text_field_1": "This is a test text",
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
+ "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark.wav",
+ "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4",
+ "_id": "1"
+ },
+ {
+ # No extensions
+ "text_field_1": "This is a test text",
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
+ "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark",
+ "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress",
+ "_id": "1"
+ }
+ ]
+ for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
+ tensor_fields = ["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"] \
+ if index_name == self.unstructured_language_bind_index_name else None
+ with self.subTest(index_name):
+ res = tensor_search.add_documents(
+ self.config,
+ add_docs_params=AddDocsParams(
+ docs=documents,
+ index_name=index_name,
+ tensor_fields=tensor_fields,
+ media_download_headers={"marqo_media_header": "media_header_test_key"}
+ )
+ )
self.assertFalse(res.errors)
\ No newline at end of file
diff --git a/tests/tensor_search/integ_tests/test_search_combined.py b/tests/tensor_search/integ_tests/test_search_combined.py
index 599a63640..285e04ef0 100644
--- a/tests/tensor_search/integ_tests/test_search_combined.py
+++ b/tests/tensor_search/integ_tests/test_search_combined.py
@@ -18,7 +18,7 @@
from marqo.tensor_search import tensor_search
from marqo.tensor_search.enums import SearchMethod
from marqo.tensor_search.models.api_models import SearchQuery
-from tests.marqo_test import MarqoTestCase, TestImageUrls
+from tests.marqo_test import MarqoTestCase, TestImageUrls, TestAudioUrls, TestVideoUrls
class TestSearch(MarqoTestCase):
@@ -1057,3 +1057,103 @@ def test_lexical_search_DoesNotErrorWithEscapedQuotes(self):
)
self.assertEqual(len(expected_ids), len(res['hits']))
self.assertEqual(set(expected_ids), {hit['_id'] for hit in res['hits']})
+
+
+@pytest.mark.largemodel
+class TestLanguageBindModelAddDocumentCombined(MarqoTestCase):
+ """A class to test the search with the LanguageBind model."""
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ super().setUpClass()
+
+ structured_language_bind_index = cls.structured_marqo_index_request(
+ name="structured_image_index" + str(uuid.uuid4()).replace('-', ''),
+ fields=[
+ FieldRequest(name="text_field_1", type=FieldType.Text,
+ features=[FieldFeature.Filter, FieldFeature.LexicalSearch]),
+ FieldRequest(name="image_field_1", type=FieldType.ImagePointer),
+ FieldRequest(name="audio_field_1", type=FieldType.AudioPointer),
+ FieldRequest(name="video_field_1", type=FieldType.VideoPointer),
+ FieldRequest(
+ name="multimodal_field",
+ type=FieldType.MultimodalCombination,
+ dependent_fields={
+ "image_field_1": 1.0,
+ "text_field_1": 1.0,
+ "audio_field_1": 1.0,
+ "video_field_1": 1.0,
+ }
+ )
+ ],
+ model=Model(name="LanguageBind/Video_V1.5_FT_Audio_FT_Image"),
+ tensor_fields=["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"],
+ )
+
+ unstructured_language_bind_index = cls.unstructured_marqo_index_request(
+ name="unstructured_image_index" + str(uuid.uuid4()).replace('-', ''),
+ model=Model(name="LanguageBind/Video_V1.5_FT_Audio_FT_Image"),
+ treat_urls_and_pointers_as_images=True,
+ treat_urls_and_pointers_as_media=True
+ )
+
+ cls.indexes = cls.create_indexes([structured_language_bind_index, unstructured_language_bind_index])
+
+ cls.structured_language_bind_index_name = structured_language_bind_index.name
+ cls.unstructured_language_bind_index_name = unstructured_language_bind_index.name
+
+ s2_inference.clear_loaded_models()
+
+ @classmethod
+ def tearDownClass(cls) -> None:
+ super().tearDownClass()
+ s2_inference.clear_loaded_models()
+
+ def test_language_bind_model_can_search_all_media_modalities(self):
+        """Test to ensure that the LanguageBind model can search all media types in the index"""
+ queries = [
+ "This is a test text",
+ TestImageUrls.IMAGE1.value,
+ TestAudioUrls.AUDIO1.value,
+ TestVideoUrls.VIDEO1.value,
+ {
+ "This is a test text": 1,
+ TestImageUrls.IMAGE1.value: 1,
+ TestAudioUrls.AUDIO1.value: 1,
+ TestVideoUrls.VIDEO1.value: 1
+ }
+ ]
+ for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
+ for query in queries:
+ with self.subTest(index_name):
+ _ = tensor_search.search(
+ config = self.config,
+ index_name=index_name,
+ text=query,
+ search_method=SearchMethod.LEXICAL
+ )
+
+ def test_language_bind_model_can_search_all_private_media_modalities(self):
+        """A test to ensure that the LanguageBind model can search all private media types in the index"""
+ queries = [
+ "This is a test text",
+ "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
+ "https://d2k91vq0avo7lq.cloudfront.net/bark.wav",
+ "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4",
+ {
+ "This is a test text": 1,
+ "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png": 1,
+ "https://d2k91vq0avo7lq.cloudfront.net/bark.wav": 1,
+ "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4": 1
+ }
+ ]
+ for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
+ for query in queries:
+ with self.subTest(index_name):
+ _ = tensor_search.search(
+ config = self.config,
+ index_name=index_name,
+ text=query,
+ search_method=SearchMethod.LEXICAL,
+ media_download_headers={"marqo_media_header": "media_header_test_key"}
+ )
\ No newline at end of file
From bc35efbd1aec29043970185f24046b11812870b0 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 10:15:59 +1100
Subject: [PATCH 17/29] Fix tests
---
tests/tensor_search/integ_tests/test_search_combined.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/tests/tensor_search/integ_tests/test_search_combined.py b/tests/tensor_search/integ_tests/test_search_combined.py
index 285e04ef0..f7e826860 100644
--- a/tests/tensor_search/integ_tests/test_search_combined.py
+++ b/tests/tensor_search/integ_tests/test_search_combined.py
@@ -1060,7 +1060,7 @@ def test_lexical_search_DoesNotErrorWithEscapedQuotes(self):
@pytest.mark.largemodel
-class TestLanguageBindModelAddDocumentCombined(MarqoTestCase):
+class TestLanguageBindModelSearchCombined(MarqoTestCase):
"""A class to test the search with the LanguageBind model."""
@classmethod
@@ -1130,7 +1130,7 @@ def test_language_bind_model_can_search_all_media_modalities(self):
config = self.config,
index_name=index_name,
text=query,
- search_method=SearchMethod.LEXICAL
+ search_method=SearchMethod.TENSOR
)
def test_language_bind_model_can_search_all_private_media_modalities(self):
@@ -1154,6 +1154,6 @@ def test_language_bind_model_can_search_all_private_media_modalities(self):
config = self.config,
index_name=index_name,
text=query,
- search_method=SearchMethod.LEXICAL,
+ search_method=SearchMethod.TENSOR,
media_download_headers={"marqo_media_header": "media_header_test_key"}
)
\ No newline at end of file
From 674bb17a3f7486750ed3d07855e4c669a34b4223 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 10:56:02 +1100
Subject: [PATCH 18/29] Fix headers for media
---
src/marqo/tensor_search/add_docs.py | 21 ++++++++++++-------
.../streaming_media_processor.py | 16 ++++++++++----
2 files changed, 26 insertions(+), 11 deletions(-)
diff --git a/src/marqo/tensor_search/add_docs.py b/src/marqo/tensor_search/add_docs.py
index 1643c432d..9906075a7 100644
--- a/src/marqo/tensor_search/add_docs.py
+++ b/src/marqo/tensor_search/add_docs.py
@@ -165,9 +165,12 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
try:
- processed_chunks = download_and_chunk_media(doc[field], device, media_download_headers, inferred_modality,
- marqo_index_type, marqo_index_model, preprocessors,
- audio_preprocessing, video_preprocessing)
+ processed_chunks = download_and_chunk_media(
+ url=doc[field], device=device, modality=inferred_modality,
+ marqo_index_type=marqo_index_type, marqo_index_model=marqo_index_model,
+ preprocessors=preprocessors, audio_preprocessing=audio_preprocessing,
+ video_preprocessing=video_preprocessing, media_download_headers=media_download_headers
+ )
media_repo[doc[field]] = processed_chunks
except (ffmpeg.Error, S2InferenceError) as e:
logger.error(f"Error processing {inferred_modality} file: {str(e)}")
@@ -197,13 +200,17 @@ def threaded_download_and_preprocess_content(allocated_docs: List[dict],
continue
-def download_and_chunk_media(url: str, device: str, headers: dict, modality: Modality, marqo_index_type: IndexType, marqo_index_model: Model,
+def download_and_chunk_media(url: str, device: str, modality: Modality, marqo_index_type: IndexType, marqo_index_model: Model,
preprocessors: Preprocessors, audio_preprocessing: AudioPreProcessing = None,
- video_preprocessing: VideoPreProcessing = None) -> List[Dict[str, torch.Tensor]]:
+ video_preprocessing: VideoPreProcessing = None,
+ media_download_headers: Optional[Dict] = None) -> List[Dict[str, torch.Tensor]]:
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB in bytes
- processor = StreamingMediaProcessor(url, device, headers, modality, marqo_index_type, marqo_index_model, preprocessors,
- audio_preprocessing, video_preprocessing)
+ processor = StreamingMediaProcessor(
+ url=url, device=device, modality=modality, marqo_index_type=marqo_index_type, marqo_index_model=marqo_index_model,
+ preprocessors=preprocessors, audio_preprocessing=audio_preprocessing, video_preprocessing=video_preprocessing,
+ media_download_headers=media_download_headers
+ )
if processor.total_size > MAX_FILE_SIZE:
raise ValueError(
diff --git a/src/marqo/tensor_search/streaming_media_processor.py b/src/marqo/tensor_search/streaming_media_processor.py
index 72b75de3c..a972739d7 100644
--- a/src/marqo/tensor_search/streaming_media_processor.py
+++ b/src/marqo/tensor_search/streaming_media_processor.py
@@ -18,12 +18,11 @@
class StreamingMediaProcessor:
- def __init__(self, url: str, device: str, headers: Dict[str, str], modality: Modality, marqo_index_type: IndexType,
+ def __init__(self, url: str, device: str, modality: Modality, marqo_index_type: IndexType,
marqo_index_model: Model, preprocessors: Preprocessors, audio_preprocessing: AudioPreProcessing = None,
- video_preprocessing: VideoPreProcessing = None):
+ video_preprocessing: VideoPreProcessing = None, media_download_headers: Optional[Dict[str, str] ]= None):
self.url = url
self.device = device
- self.headers = headers
self.modality = modality
self.marqo_index_type = marqo_index_type
self.marqo_index_model = marqo_index_model
@@ -33,6 +32,10 @@ def __init__(self, url: str, device: str, headers: Dict[str, str], modality: Mod
self.preprocessor = self.preprocessors[modality]
self.total_size, self.duration = self._fetch_file_metadata()
+ if media_download_headers is None:
+ media_download_headers = {}
+ self.media_download_headers = media_download_headers
+
self._set_split_parameters(modality)
self._log_initialization_details()
@@ -67,6 +70,8 @@ def _fetch_file_metadata(self):
'probesize': '256K' # Probe only the first 256KB
}
+ probe_options.update(self.media_download_headers)
+
probe = ffmpeg.probe(self.url, **probe_options)
size = int(probe['format'].get('size', 0))
@@ -105,7 +110,10 @@ def process_media(self) -> List[Dict[str, torch.Tensor]]:
try:
# Use ffmpeg-python to process the chunk
- stream = ffmpeg.input(self.url, ss=chunk_start, t=chunk_end - chunk_start)
+ stream = ffmpeg.input(
+ self.url, ss=chunk_start, t=chunk_end - chunk_start,
+ headers=self.media_download_headers
+ )
if self.modality == Modality.VIDEO:
stream = ffmpeg.output(stream, output_file, vcodec='libx264', acodec='aac', **{'f': 'mp4'})
From 27da146e516d00d389993044e8425221f87fbc74 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 10:58:35 +1100
Subject: [PATCH 19/29] Fix headers for media
---
src/marqo/tensor_search/streaming_media_processor.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/marqo/tensor_search/streaming_media_processor.py b/src/marqo/tensor_search/streaming_media_processor.py
index a972739d7..2aa1f4c7f 100644
--- a/src/marqo/tensor_search/streaming_media_processor.py
+++ b/src/marqo/tensor_search/streaming_media_processor.py
@@ -30,12 +30,12 @@ def __init__(self, url: str, device: str, modality: Modality, marqo_index_type:
self.video_preprocessing = video_preprocessing
self.preprocessors = preprocessors
self.preprocessor = self.preprocessors[modality]
- self.total_size, self.duration = self._fetch_file_metadata()
-
if media_download_headers is None:
media_download_headers = {}
self.media_download_headers = media_download_headers
+ self.total_size, self.duration = self._fetch_file_metadata()
+
self._set_split_parameters(modality)
self._log_initialization_details()
From c37fe9e1e188f8d97335ceb079ba23ad839e15da Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 10:59:56 +1100
Subject: [PATCH 20/29] Fix headers for media
---
src/marqo/tensor_search/streaming_media_processor.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/src/marqo/tensor_search/streaming_media_processor.py b/src/marqo/tensor_search/streaming_media_processor.py
index 2aa1f4c7f..786ccbd52 100644
--- a/src/marqo/tensor_search/streaming_media_processor.py
+++ b/src/marqo/tensor_search/streaming_media_processor.py
@@ -20,7 +20,7 @@
class StreamingMediaProcessor:
def __init__(self, url: str, device: str, modality: Modality, marqo_index_type: IndexType,
marqo_index_model: Model, preprocessors: Preprocessors, audio_preprocessing: AudioPreProcessing = None,
- video_preprocessing: VideoPreProcessing = None, media_download_headers: Optional[Dict[str, str] ]= None):
+                 video_preprocessing: VideoPreProcessing = None, media_download_headers: Optional[Dict[str, str]] = None):
self.url = url
self.device = device
self.modality = modality
@@ -67,11 +67,10 @@ def _fetch_file_metadata(self):
'v': 'error',
'show_entries': 'format=size,duration',
'of': 'json',
- 'probesize': '256K' # Probe only the first 256KB
+ 'probesize': '256K', # Probe only the first 256KB
+ 'headers': self.media_download_headers
}
- probe_options.update(self.media_download_headers)
-
probe = ffmpeg.probe(self.url, **probe_options)
size = int(probe['format'].get('size', 0))
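Note on the change above: ffmpeg's -headers option expects a single string of CRLF-separated header lines rather than a Python dict, which is why the following patch converts the mapping before handing it to ffprobe/ffmpeg. A small sketch of the intended shape (illustrative only; the URL and header values are placeholders, and ffmpeg.get_args only assembles the argument list, so no ffmpeg binary is invoked):

    import ffmpeg

    url = "https://example.com/video.mp4"                      # placeholder URL
    headers = {"marqo_media_header": "media_header_test_key"}  # placeholder header

    # What ffmpeg ultimately consumes: one CRLF-separated header string.
    header_string = "\r\n".join(f"{k}: {v}" for k, v in headers.items())

    stream = ffmpeg.input(url, ss=0, t=10, headers=header_string)
    stream = ffmpeg.output(stream, "chunk.mp4", vcodec="libx264", acodec="aac", f="mp4")

    # '-headers' appears among the input options, followed by header_string.
    print(ffmpeg.get_args(stream))
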
From a35b1c90d9a760b9379109f126d42f62683b6be3 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 11:14:35 +1100
Subject: [PATCH 21/29] Fix media download headers for video and audio
---
.../streaming_media_processor.py | 39 +++++++++++++++----
1 file changed, 31 insertions(+), 8 deletions(-)
diff --git a/src/marqo/tensor_search/streaming_media_processor.py b/src/marqo/tensor_search/streaming_media_processor.py
index 786ccbd52..56d285637 100644
--- a/src/marqo/tensor_search/streaming_media_processor.py
+++ b/src/marqo/tensor_search/streaming_media_processor.py
@@ -15,6 +15,7 @@
from marqo.core.models.marqo_index import *
from marqo.s2_inference.multimodal_model_load import Modality
from marqo.tensor_search.models.preprocessors_model import Preprocessors
+from marqo.core.exceptions import InternalError
class StreamingMediaProcessor:
@@ -30,9 +31,7 @@ def __init__(self, url: str, device: str, modality: Modality, marqo_index_type:
self.video_preprocessing = video_preprocessing
self.preprocessors = preprocessors
self.preprocessor = self.preprocessors[modality]
- if media_download_headers is None:
- media_download_headers = {}
- self.media_download_headers = media_download_headers
+ self.media_download_headers = self._convert_headers_to_cli_format(media_download_headers)
self.total_size, self.duration = self._fetch_file_metadata()
@@ -59,6 +58,25 @@ def _log_initialization_details(self):
# print(f"from StreamingMediaProcessor, self.duration: {self.duration}")
pass
+ def _convert_headers_to_cli_format(self, raw_media_download_headers: Optional[Dict] = None) -> str:
+ """
+ A helper function to convert the media download headers into a format that can be passed to ffmpeg in
+ subprocess calls.
+
+ Examples:
+ If the headers are {"key1": "value1", "key2": "value2"}, the function will return a string
+ "key1: value1\r\nkey2: value2"
+
+ Returns:
+ str: The headers in the required format. An empty string if no headers or None are provided.
+ """
+ if raw_media_download_headers is None or raw_media_download_headers == {}:
+ return ""
+ elif not isinstance(raw_media_download_headers, dict):
+ raise InternalError("media_download_headers should be a dictionary")
+ return "\r\n".join([f"{key}: {value}" for key, value in raw_media_download_headers.items()])
+
+
def _fetch_file_metadata(self):
start_time = time.time()
@@ -68,9 +86,11 @@ def _fetch_file_metadata(self):
'show_entries': 'format=size,duration',
'of': 'json',
'probesize': '256K', # Probe only the first 256KB
- 'headers': self.media_download_headers
}
+ if self.media_download_headers:
+ probe_options['headers'] = self.media_download_headers
+
probe = ffmpeg.probe(self.url, **probe_options)
size = int(probe['format'].get('size', 0))
@@ -109,10 +129,13 @@ def process_media(self) -> List[Dict[str, torch.Tensor]]:
try:
# Use ffmpeg-python to process the chunk
- stream = ffmpeg.input(
- self.url, ss=chunk_start, t=chunk_end - chunk_start,
- headers=self.media_download_headers
- )
+ if self.media_download_headers:
+ stream = ffmpeg.input(
+ self.url, ss=chunk_start, t=chunk_end - chunk_start,
+ headers=self.media_download_headers
+ )
+ else:
+ stream = ffmpeg.input(self.url, ss=chunk_start, t=chunk_end - chunk_start)
if self.modality == Modality.VIDEO:
stream = ffmpeg.output(stream, output_file, vcodec='libx264', acodec='aac', **{'f': 'mp4'})
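For clarity, here is a standalone sketch of the conversion behaviour introduced above. It mirrors _convert_headers_to_cli_format and the conditional probe options without importing Marqo; the header names and values are illustrative only:

    from typing import Dict, Optional

    def convert_headers_to_cli_format(raw_headers: Optional[Dict[str, str]] = None) -> str:
        """Join headers as 'Key: value' pairs separated by CRLF, or return '' when absent."""
        if not raw_headers:
            return ""
        return "\r\n".join(f"{key}: {value}" for key, value in raw_headers.items())

    cli_headers = convert_headers_to_cli_format({"key1": "value1", "key2": "value2"})
    print(repr(cli_headers))  # 'key1: value1\r\nkey2: value2'

    # The probe options only gain a 'headers' entry when headers were supplied,
    # matching the conditional added to _fetch_file_metadata.
    probe_options = {'v': 'error', 'show_entries': 'format=size,duration',
                     'of': 'json', 'probesize': '256K'}
    if cli_headers:
        probe_options['headers'] = cli_headers
    print(probe_options)
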
From 532cc4ef0f093f6954239f4603708a3861dd8c93 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 11:20:20 +1100
Subject: [PATCH 22/29] Fix tests
---
tests/tensor_search/integ_tests/test_add_documents_combined.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index e7cc77ff2..526ca2c9c 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -1259,7 +1259,7 @@ def test_language_bind_model_can_add_all_private_media_modalities(self):
"image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
"audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark",
"video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress",
- "_id": "1"
+ "_id": ""
}
]
for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
From f85ee8b72738915d928217b04b4c6542f059e220 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 11:48:57 +1100
Subject: [PATCH 23/29] Convert image to RGB for languagebind
---
.../languagebind/image/processing_image.py | 5 ++++
.../test_add_documents_combined.py | 27 ++++++++++---------
2 files changed, 19 insertions(+), 13 deletions(-)
diff --git a/src/marqo/s2_inference/languagebind/image/processing_image.py b/src/marqo/s2_inference/languagebind/image/processing_image.py
index 7a3d7c396..90f80b155 100644
--- a/src/marqo/s2_inference/languagebind/image/processing_image.py
+++ b/src/marqo/s2_inference/languagebind/image/processing_image.py
@@ -13,10 +13,15 @@ def make_list_of_images(x):
return x
+def _convert_to_rgb(image):
+ return image.convert("RGB")
+
+
def get_image_transform(config):
config = config.vision_config
transform = transforms.Compose(
[
+ _convert_to_rgb,
transforms.ToTensor(),
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
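The reason _convert_to_rgb sits first in the Compose pipeline is that palette ("P") or RGBA images would otherwise reach ToTensor with the wrong channel count. A minimal, self-contained illustration using a synthetic image (no model weights or network access; any further steps of the real transform are omitted):

    from PIL import Image
    from torchvision import transforms

    def _convert_to_rgb(image):
        return image.convert("RGB")

    transform = transforms.Compose([
        _convert_to_rgb,  # palette/RGBA -> 3-channel RGB before tensor conversion
        transforms.ToTensor(),
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
    ])

    palette_image = Image.new("P", (300, 200))  # palette-mode image, like an indexed PNG
    print(transform(palette_image).shape)       # torch.Size([3, 224, 224])
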
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index 526ca2c9c..b5498b67f 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -1247,22 +1247,23 @@ def test_language_bind_model_can_add_all_media_modalities(self):
def test_language_bind_model_can_add_all_private_media_modalities(self):
documents = [
{ # With extensions
- "text_field_1": "This is a test text",
+ #"text_field_1": "This is a test text",
"image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
- "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark.wav",
- "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4",
- "_id": "1"
+ # "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark.wav",
+ # "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4",
+ # "_id": "1"
},
- {
- # No extensions
- "text_field_1": "This is a test text",
- "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
- "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark",
- "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress",
- "_id": ""
- }
+ # {
+ # # No extensions
+ # "text_field_1": "This is a test text",
+ # "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
+ # "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark",
+ # "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress",
+ # "_id": "2"
+ # }
]
- for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
+ # for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
+ for index_name in [self.structured_language_bind_index_name]:
tensor_fields = ["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"] \
if index_name == self.unstructured_language_bind_index_name else None
with self.subTest(index_name):
From bace6fde6a53b4f1ef3c2c4833054bfcf669d182 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 12:16:27 +1100
Subject: [PATCH 24/29] Fix tests
---
.../s2_inference/multimodal_model_load.py | 2 +-
.../test_add_documents_combined.py | 27 +++++++++----------
2 files changed, 14 insertions(+), 15 deletions(-)
diff --git a/src/marqo/s2_inference/multimodal_model_load.py b/src/marqo/s2_inference/multimodal_model_load.py
index ad4bb2506..593c45bf5 100644
--- a/src/marqo/s2_inference/multimodal_model_load.py
+++ b/src/marqo/s2_inference/multimodal_model_load.py
@@ -292,7 +292,7 @@ def encode(self, content, modality, normalize=True, media_download_headers: Opti
# If media has already been preprocessed
inputs[modality.value] = to_device(content[0], self.model.device)['pixel_values']
elif isinstance(content[0], str) and 'http' in content[0]:
- return self.encode(content[0], modality=modality)
+ return self.encode(content[0], modality=modality, media_download_headers=media_download_headers)
else:
raise ValueError(f"Unsupported {modality.value} content type: {type(content)}, content: {content}")
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index b5498b67f..7201ff4ab 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -1247,23 +1247,22 @@ def test_language_bind_model_can_add_all_media_modalities(self):
def test_language_bind_model_can_add_all_private_media_modalities(self):
documents = [
{ # With extensions
- #"text_field_1": "This is a test text",
+ "text_field_1": "This is a test text",
"image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
- # "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark.wav",
- # "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4",
- # "_id": "1"
+ "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark.wav",
+ "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4",
+ "_id": "1"
},
- # {
- # # No extensions
- # "text_field_1": "This is a test text",
- # "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
- # "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark",
- # "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress",
- # "_id": "2"
- # }
+ {
+ # No extensions
+ "text_field_1": "This is a test text",
+ "image_field_1": "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small",
+ "audio_field_1": "https://d2k91vq0avo7lq.cloudfront.net/bark",
+ "video_field_1": "https://d2k91vq0avo7lq.cloudfront.net/congress",
+ "_id": "2"
+ }
]
- # for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
- for index_name in [self.structured_language_bind_index_name]:
+ for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
tensor_fields = ["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"] \
if index_name == self.unstructured_language_bind_index_name else None
with self.subTest(index_name):
From 72e4b366e634a1b0b16942a53069e4bf1c3dfbaa Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 12:33:38 +1100
Subject: [PATCH 25/29] Delete a test
---
.../integ_tests/test_search_combined.py | 46 -------------------
1 file changed, 46 deletions(-)
diff --git a/tests/tensor_search/integ_tests/test_search_combined.py b/tests/tensor_search/integ_tests/test_search_combined.py
index f7e826860..fb8fa9d4b 100644
--- a/tests/tensor_search/integ_tests/test_search_combined.py
+++ b/tests/tensor_search/integ_tests/test_search_combined.py
@@ -1012,52 +1012,6 @@ def test_search_over_private_images_with_media_download_headers(self):
media_download_headers={"marqo_media_header": "media_header_test_key"}
)
- def test_lexical_search_DoesNotErrorWithEscapedQuotes(self):
- """
- Ensure that lexical search handles double quotes properly, both escaped and wrong quotes.
- Expected behavior: escaped quotes are passed to vespa. Incorrect quotes are treated like whitespace.
- """
-
- docs_list = [
- {"_id": "doc1", "text_field_1": '1"2'},
- {"_id": "doc2", "text_field_1": 'exact match'},
- {"_id": "doc3", "text_field_1": 'exacto wrong syntax'},
- {"_id": "doc4", "text_field_1": '"escaped"'},
-
- {"_id": "red_herring_1", "text_field_1": '12'},
- {"_id": "red_herring_2", "text_field_1": 'escaped'},
- {"_id": "red_herring_3", "text_field_1": 'wrong"'}
- ]
- test_cases = [
- ('1\\"2', ['doc1']), # Match off of '1"2'
- ('"exact match"', ['doc2']), # Match off of 'exact match'
- ('\\"escaped\\"', ['doc4', 'red_herring_2']), # Match off of 'escaped' or '"escaped"'
- ('"exacto" wrong"', ['doc3']), # Match properly off of 'wrong'
- ('""', []), # Single quote should return no results (treated as whitespace)
- ('"', []), # Double quote should return no results (treated as whitespace)
- ('', []) # Empty string should return no results
- ]
-
- for index in [self.unstructured_default_text_index, self.structured_default_text_index]:
- with self.subTest(index=index.type):
- tensor_search.add_documents(
- config=self.config,
- add_docs_params=AddDocsParams(
- index_name=index.name,
- docs=docs_list,
- tensor_fields=["text_field_1"] if isinstance(index, UnstructuredMarqoIndex) else None
- )
- )
-
- for query, expected_ids in test_cases:
- with self.subTest(query=query):
- res = tensor_search.search(
- text=query, config=self.config, index_name=index.name,
- search_method=SearchMethod.LEXICAL
- )
- self.assertEqual(len(expected_ids), len(res['hits']))
- self.assertEqual(set(expected_ids), {hit['_id'] for hit in res['hits']})
-
@pytest.mark.largemodel
class TestLanguageBindModelSearchCombined(MarqoTestCase):
From 96e22311d50888f5b61781248e22252f1ced0fe6 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 12:35:03 +1100
Subject: [PATCH 26/29] Add back the test
---
.../integ_tests/test_search_combined.py | 195 +++++-------------
1 file changed, 48 insertions(+), 147 deletions(-)
diff --git a/tests/tensor_search/integ_tests/test_search_combined.py b/tests/tensor_search/integ_tests/test_search_combined.py
index fb8fa9d4b..514a92e99 100644
--- a/tests/tensor_search/integ_tests/test_search_combined.py
+++ b/tests/tensor_search/integ_tests/test_search_combined.py
@@ -1,24 +1,23 @@
import os
import uuid
from unittest import mock
-
-import pytest
import torch
+import pytest
import marqo.core.exceptions as core_exceptions
-from marqo import exceptions as base_exceptions
-from marqo.core.models.add_docs_params import AddDocsParams
from marqo.core.models.marqo_index import *
from marqo.core.models.marqo_index_request import FieldRequest
+from marqo.tensor_search import tensor_search
+from marqo.tensor_search.enums import SearchMethod
+from marqo.core.models.add_docs_params import AddDocsParams
+from tests.marqo_test import MarqoTestCase, TestImageUrls
+from marqo import exceptions as base_exceptions
from marqo.core.models.marqo_query import MarqoLexicalQuery
from marqo.core.models.score_modifier import ScoreModifierType, ScoreModifier
from marqo.core.structured_vespa_index.structured_vespa_index import StructuredVespaIndex
from marqo.core.unstructured_vespa_index.unstructured_vespa_index import UnstructuredVespaIndex
-from marqo.s2_inference.errors import MediaDownloadError
-from marqo.tensor_search import tensor_search
-from marqo.tensor_search.enums import SearchMethod
from marqo.tensor_search.models.api_models import SearchQuery
-from tests.marqo_test import MarqoTestCase, TestImageUrls, TestAudioUrls, TestVideoUrls
+from pydantic import ValidationError
class TestSearch(MarqoTestCase):
@@ -205,7 +204,7 @@ def test_search_video(self):
documents = [
{"video_field_1": "https://marqo-k400-video-test-dataset.s3.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4", "_id": "1"},
# Replace the audio link with something marqo-hosted
- {"audio_field_1": "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3", "_id": "2"},
+ {"audio_field_1": "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3", "_id": "2"},
{"image_field_1": TestImageUrls.HIPPO_REALISTIC_LARGE.value, "_id": "3"},
# {"image_field_1": TestImageUrls.HIPPO_REALISTIC.value, "_id": "5"}, # png image with palette is not supported
{"text_field_1": "hello there padawan. Today you will begin your training to be a Jedi", "_id": "4"},
@@ -240,7 +239,7 @@ def test_search_audio(self):
documents = [
{"video_field_1": "https://marqo-k400-video-test-dataset.s3.amazonaws.com/videos/---QUuC4vJs_000084_000094.mp4", "_id": "1"},
# Replace the audio link with something marqo-hosted
- {"audio_field_1": "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3", "_id": "2"},
+ {"audio_field_1": "https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3", "_id": "2"},
{"image_field_1": TestImageUrls.HIPPO_REALISTIC_LARGE.value, "_id": "3"},
# {"image_field_1": TestImageUrls.HIPPO_REALISTIC.value, "_id": "5"}, # png file with palette is not supported
{"text_field_1": "hello there padawan. Today you will begin your training to be a Jedi", "_id": "4"},
@@ -263,7 +262,7 @@ def test_search_audio(self):
index_name=index.name,
text="https://marqo-ecs-50-audio-test-dataset.s3.amazonaws.com/audios/marqo-audio-test.mp3"
)
-
+
# Assertions
self.assertEqual(len(results['hits']), 3) # 3 documents should be returned (limit=3)
self.assertEqual(results['hits'][0]['_id'], "2") # The audio document should be the top result
@@ -968,146 +967,48 @@ def test_search_query_CanAcceptDifferentSearchMethods(self):
search_query = SearchQuery(q="test")
self.assertEqual(SearchMethod.TENSOR, search_query.searchMethod)
- def test_search_private_images_proper_error_raised(self):
- """Test that search raises a MediaDownloadError when trying to access private images"""
- test_indexes = [
- self.unstructured_default_image_index,
- self.structured_default_image_index
- ]
+ def test_lexical_search_DoesNotErrorWithEscapedQuotes(self):
+ """
+ Ensure that lexical search handles double quotes properly, both escaped and wrong quotes.
+ Expected behavior: escaped quotes are passed to vespa. Incorrect quotes are treated like whitespace.
+ """
- test_queries = [({
- "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png": 1,
- "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small": 1 }, "dictionary queries"),
- ("https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "str queries")]
- for index_name in test_indexes:
- for query, msg in test_queries:
- with self.subTest(msg=f"index: {index_name}, query: {msg}"):
- with self.assertRaises(MediaDownloadError):
- _ = tensor_search.search(
- config=self.config,
- index_name=index_name.name,
- text=query,
- search_method=SearchMethod.TENSOR,
- )
+ docs_list = [
+ {"_id": "doc1", "text_field_1": '1"2'},
+ {"_id": "doc2", "text_field_1": 'exact match'},
+ {"_id": "doc3", "text_field_1": 'exacto wrong syntax'},
+ {"_id": "doc4", "text_field_1": '"escaped"'},
- def test_search_over_private_images_with_media_download_headers(self):
- """Test that search can use private images with media download headers"""
- test_indexes = [
- self.unstructured_default_image_index,
- self.structured_default_image_index
+ {"_id": "red_herring_1", "text_field_1": '12'},
+ {"_id": "red_herring_2", "text_field_1": 'escaped'},
+ {"_id": "red_herring_3", "text_field_1": 'wrong"'}
+ ]
+ test_cases = [
+ ('1\\"2', ['doc1']), # Match off of '1"2'
+ ('"exact match"', ['doc2']), # Match off of 'exact match'
+ ('\\"escaped\\"', ['doc4', 'red_herring_2']), # Match off of 'escaped' or '"escaped"'
+ ('"exacto" wrong"', ['doc3']), # Match properly off of 'wrong'
+ ('""', []), # Single quote should return no results (treated as whitespace)
+ ('"', []), # Double quote should return no results (treated as whitespace)
+ ('', []) # Empty string should return no results
]
- test_queries = [({
- "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png": 1,
- "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small": 1 }, "dictionary queries"),
- ("https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small", "str queries")]
- for index_name in test_indexes:
- for query, msg in test_queries:
- with self.subTest(msg=f"index: {index_name}, query: {msg}"):
- _ = tensor_search.search(
- config=self.config,
- index_name=index_name.name,
- text=query,
- search_method=SearchMethod.TENSOR,
- media_download_headers={"marqo_media_header": "media_header_test_key"}
+ for index in [self.unstructured_default_text_index, self.structured_default_text_index]:
+ with self.subTest(index=index.type):
+ tensor_search.add_documents(
+ config=self.config,
+ add_docs_params=AddDocsParams(
+ index_name=index.name,
+ docs=docs_list,
+ tensor_fields=["text_field_1"] if isinstance(index, UnstructuredMarqoIndex) else None
)
-
-
-@pytest.mark.largemodel
-class TestLanguageBindModelSearchCombined(MarqoTestCase):
- """A class to test the search with the LanguageBind model."""
-
- @classmethod
- def setUpClass(cls) -> None:
- super().setUpClass()
-
- structured_language_bind_index = cls.structured_marqo_index_request(
- name="structured_image_index" + str(uuid.uuid4()).replace('-', ''),
- fields=[
- FieldRequest(name="text_field_1", type=FieldType.Text,
- features=[FieldFeature.Filter, FieldFeature.LexicalSearch]),
- FieldRequest(name="image_field_1", type=FieldType.ImagePointer),
- FieldRequest(name="audio_field_1", type=FieldType.AudioPointer),
- FieldRequest(name="video_field_1", type=FieldType.VideoPointer),
- FieldRequest(
- name="multimodal_field",
- type=FieldType.MultimodalCombination,
- dependent_fields={
- "image_field_1": 1.0,
- "text_field_1": 1.0,
- "audio_field_1": 1.0,
- "video_field_1": 1.0,
- }
)
- ],
- model=Model(name="LanguageBind/Video_V1.5_FT_Audio_FT_Image"),
- tensor_fields=["text_field_1", "image_field_1", "audio_field_1", "video_field_1", "multimodal_field"],
- )
- unstructured_language_bind_index = cls.unstructured_marqo_index_request(
- name="unstructured_image_index" + str(uuid.uuid4()).replace('-', ''),
- model=Model(name="LanguageBind/Video_V1.5_FT_Audio_FT_Image"),
- treat_urls_and_pointers_as_images=True,
- treat_urls_and_pointers_as_media=True
- )
-
- cls.indexes = cls.create_indexes([structured_language_bind_index, unstructured_language_bind_index])
-
- cls.structured_language_bind_index_name = structured_language_bind_index.name
- cls.unstructured_language_bind_index_name = unstructured_language_bind_index.name
-
- s2_inference.clear_loaded_models()
-
- @classmethod
- def tearDownClass(cls) -> None:
- super().tearDownClass()
- s2_inference.clear_loaded_models()
-
- def test_language_bind_model_can_search_all_media_modalities(self):
-        """Test to ensure that the LanguageBind model can search all media types in the index"""
- queries = [
- "This is a test text",
- TestImageUrls.IMAGE1.value,
- TestAudioUrls.AUDIO1.value,
- TestVideoUrls.VIDEO1.value,
- {
- "This is a test text": 1,
- TestImageUrls.IMAGE1.value: 1,
- TestAudioUrls.AUDIO1.value: 1,
- TestVideoUrls.VIDEO1.value: 1
- }
- ]
- for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
- for query in queries:
- with self.subTest(index_name):
- _ = tensor_search.search(
- config = self.config,
- index_name=index_name,
- text=query,
- search_method=SearchMethod.TENSOR
- )
-
- def test_language_bind_model_can_search_all_private_media_modalities(self):
-        """A test to ensure that the LanguageBind model can search all private media types in the index"""
- queries = [
- "This is a test text",
- "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png",
- "https://d2k91vq0avo7lq.cloudfront.net/bark.wav",
- "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4",
- {
- "This is a test text": 1,
- "https://d2k91vq0avo7lq.cloudfront.net/ai_hippo_realistic_small.png": 1,
- "https://d2k91vq0avo7lq.cloudfront.net/bark.wav": 1,
- "https://d2k91vq0avo7lq.cloudfront.net/congress.mp4": 1
- }
- ]
- for index_name in [self.structured_language_bind_index_name, self.unstructured_language_bind_index_name]:
- for query in queries:
- with self.subTest(index_name):
- _ = tensor_search.search(
- config = self.config,
- index_name=index_name,
- text=query,
- search_method=SearchMethod.TENSOR,
- media_download_headers={"marqo_media_header": "media_header_test_key"}
- )
\ No newline at end of file
+ for query, expected_ids in test_cases:
+ with self.subTest(query=query):
+ res = tensor_search.search(
+ text=query, config=self.config, index_name=index.name,
+ search_method=SearchMethod.LEXICAL
+ )
+ self.assertEqual(len(expected_ids), len(res['hits']))
+ self.assertEqual(set(expected_ids), {hit['_id'] for hit in res['hits']})
From 2353e6a68daeb5d59a785153a0e4385896f6426b Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 12:56:04 +1100
Subject: [PATCH 27/29] Change largemodel tests logic
---
tests/conftest.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tests/conftest.py b/tests/conftest.py
index 36d1b9617..93d52e8ed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,9 +18,9 @@ def pytest_collection_modifyitems(config, items):
skip_cpu_only = pytest.mark.skip(reason="skip in --largemodel mode when cpu_only is present")
if config.getoption("--largemodel"):
- # --largemodel given in cli: do not skip largemodel tests, skip cpu_only tests
+ # --largemodel given in cli: only run tests that have largemodel marker
for item in items:
- if "cpu_only" in item.keywords:
+ if "largemodel" not in item.keywords:
item.add_marker(skip_cpu_only)
else:
for item in items:
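After this change the collection hook behaves roughly as sketched below. This is a simplified reconstruction for illustration: the else branch and the --largemodel option registration live outside this hunk, and the skip-marker names here are placeholders (the patch itself reuses the existing skip_cpu_only marker):

    import pytest

    def pytest_collection_modifyitems(config, items):
        skip_largemodel = pytest.mark.skip(reason="need --largemodel option to run")
        skip_non_largemodel = pytest.mark.skip(reason="only largemodel tests run with --largemodel")

        if config.getoption("--largemodel"):
            # --largemodel given on the CLI: run only tests carrying the largemodel marker.
            for item in items:
                if "largemodel" not in item.keywords:
                    item.add_marker(skip_non_largemodel)
        else:
            # Default run: skip the expensive largemodel tests.
            for item in items:
                if "largemodel" in item.keywords:
                    item.add_marker(skip_largemodel)
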
From 573b46907517556b6299e651986d3f49a9c7b8a4 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 13:48:07 +1100
Subject: [PATCH 28/29] Fix tests
---
.../tensor_search/test_modalities_download.py | 39 ++++++++++++++-----
1 file changed, 30 insertions(+), 9 deletions(-)
diff --git a/tests/tensor_search/test_modalities_download.py b/tests/tensor_search/test_modalities_download.py
index b7158b2be..0335d2a48 100644
--- a/tests/tensor_search/test_modalities_download.py
+++ b/tests/tensor_search/test_modalities_download.py
@@ -1,17 +1,21 @@
import unittest
from unittest.mock import Mock, patch, MagicMock
-from PIL import UnidentifiedImageError
+
+import ffmpeg
+import pytest
import torch
-from marqo.s2_inference.errors import UnsupportedModalityError, S2InferenceError
-from marqo.tensor_search.add_docs import threaded_download_and_preprocess_content
+from PIL import UnidentifiedImageError
+
from marqo.core.models.marqo_index import IndexType, MarqoIndex, FieldType
-from marqo.s2_inference.s2_inference import Modality
-from marqo.s2_inference.models.model_type import ModelType
-from marqo.tensor_search.telemetry import RequestMetricsStore, RequestMetrics
from marqo.s2_inference.errors import MediaDownloadError
-import ffmpeg
+from marqo.s2_inference.errors import UnsupportedModalityError, S2InferenceError
+from marqo.s2_inference.models.model_type import ModelType
+from marqo.s2_inference.s2_inference import Modality
+from marqo.tensor_search.add_docs import threaded_download_and_preprocess_content
+from marqo.tensor_search.telemetry import RequestMetrics
+@pytest.mark.unittest
class TestThreadedDownloadAndPreprocess(unittest.TestCase):
def setUp(self):
@@ -230,13 +234,30 @@ def test_video_and_audio_unstructured_index(self, mock_infer_modality, mock_down
# Verify that download_and_chunk_media was called twice
self.assertEqual(mock_download_and_chunk.call_count, 2)
+ print(mock_download_and_chunk.call_args_list)
# Verify the calls to download_and_chunk_media
mock_download_and_chunk.assert_any_call(
- self.mock_video_url, "cpu", {}, Modality.VIDEO, self.mock_marqo_index.type, self.mock_marqo_index.model, None, None, None
+            url=self.mock_video_url,
+            device='cpu',
+            modality=Modality.VIDEO,
+            marqo_index_type=self.mock_marqo_index.type,
+            marqo_index_model=self.mock_marqo_index.model,
+            preprocessors=None,
+            audio_preprocessing=None,
+            video_preprocessing=None,
+            media_download_headers={}
)
mock_download_and_chunk.assert_any_call(
- self.mock_audio_url, "cpu", {}, Modality.AUDIO, self.mock_marqo_index.type, self.mock_marqo_index.model, None, None, None
+            url=self.mock_audio_url,
+            device='cpu',
+            modality=Modality.AUDIO,
+            marqo_index_type=self.mock_marqo_index.type,
+            marqo_index_model=self.mock_marqo_index.model,
+            preprocessors=None,
+            audio_preprocessing=None,
+            video_preprocessing=None,
+            media_download_headers={}
)
@patch("marqo.tensor_search.add_docs.download_and_chunk_media")
From a5ea94788cf671f30e917b77d5c404702df3eb53 Mon Sep 17 00:00:00 2001
From: Li Wan
Date: Fri, 25 Oct 2024 13:50:37 +1100
Subject: [PATCH 29/29] Fix tests
---
.../integ_tests/test_add_documents_combined.py | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
index 7201ff4ab..64df57273 100644
--- a/tests/tensor_search/integ_tests/test_add_documents_combined.py
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -833,13 +833,13 @@ def test_process_media_chunk_calculation(self, mock_temp_dir, mock_ffmpeg):
processor = streaming_media_processor.StreamingMediaProcessor(
url='http://example.com/video.mp4',
device='cpu',
- headers={},
modality=streaming_media_processor.Modality.VIDEO,
marqo_index_type=IndexType.Unstructured,
marqo_index_model=Model(name="test", properties={}),
audio_preprocessing=unittest.mock.Mock(),
video_preprocessing=unittest.mock.Mock(),
- preprocessors={'video': unittest.mock.Mock()}
+ preprocessors={'video': unittest.mock.Mock()},
+ media_download_headers={},
)
# Set arbitrary values
@@ -1167,8 +1167,6 @@ def test_add_private_images_success(self):
self.assertFalse(res.errors)
-
-
@pytest.mark.largemodel
class TestLanguageBindModelAddDocumentCombined(MarqoTestCase):
"""A class to test the add_documents with the LanguageBind model."""