Image download headers #60

Merged 4 commits on Feb 21, 2023
21 changes: 17 additions & 4 deletions src/marqo/index.py
@@ -116,7 +116,7 @@ def refresh(self):
def search(self, q: Union[str, dict], searchable_attributes: Optional[List[str]] = None,
limit: int = 10, offset: int = 0, search_method: Union[SearchMethods.TENSOR, str] = SearchMethods.TENSOR,
highlights=None, device: Optional[str] = None, filter_string: str = None,
show_highlights=True, reranker=None,
show_highlights=True, reranker=None, image_download_headers: Optional[Dict] = None,
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict[str,List[Union[float, int]]]] = None,
) -> Dict[str, Any]:
"""Search the index.
@@ -172,6 +172,8 @@ def search(self, q: Union[str, dict], searchable_attributes: Optional[List[str]]
body["attributesToRetrieve"] = attributes_to_retrieve
if filter_string is not None:
body["filter"] = filter_string
if image_download_headers is not None:
body["image_download_headers"] = image_download_headers
res = self.http.post(
path=path_with_query_str,
body=body
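
For illustration, a minimal sketch of the new search parameter from the client side. The client URL, index name, and header value below are placeholders, not taken from this PR, and an image-capable index is assumed to already exist:

import marqo

mq = marqo.Client(url="http://localhost:8882")  # placeholder URL

results = mq.index("my-image-index").search(
    q="a photo of a red bicycle",
    # Forwarded in the request body as "image_download_headers" (see above).
    image_download_headers={"Authorization": "Bearer <token>"},
)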
@@ -234,7 +236,8 @@ def add_documents(
processes: int = None,
device: str = None,
non_tensor_fields: List[str] = None,
use_existing_tensors: bool = False
use_existing_tensors: bool = False,
image_download_headers: dict = None
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Add documents to this index. Does a partial update on existing documents,
based on their ID. Adds unseen documents to the index.
@@ -255,17 +258,21 @@
"cuda" and "cuda:2"
non_tensor_fields: fields within documents to not create and store tensors against.
use_existing_tensors: use vectors that already exist in the docs.
image_download_headers: a dictionary of headers to be passed while downloading images,
for URLs found in documents

Returns:
Response body outlining indexing result
"""
if non_tensor_fields is None:
non_tensor_fields = []
if image_download_headers is None:
image_download_headers = dict()
return self._generic_add_update_docs(
update_method="replace",
documents=documents, auto_refresh=auto_refresh, server_batch_size=server_batch_size,
client_batch_size=client_batch_size, processes=processes, device=device, non_tensor_fields=non_tensor_fields,
use_existing_tensors=use_existing_tensors
use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers
)

def update_documents(
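
Similarly, a hedged sketch of add_documents with the new argument; the document contents, field names, and header value are illustrative only:

import marqo

mq = marqo.Client(url="http://localhost:8882")  # placeholder URL

mq.index("my-image-index").add_documents(
    [{"_id": "doc1", "image_location": "https://example.com/private/cat.jpg"}],
    # Headers the server will attach when it downloads the image URL above.
    image_download_headers={"Authorization": "Bearer <token>"},
)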
@@ -320,10 +327,13 @@ def _generic_add_update_docs(
processes: int = None,
device: str = None,
non_tensor_fields: List = None,
use_existing_tensors: bool = None
use_existing_tensors: bool = False,
image_download_headers: dict = None
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:

if non_tensor_fields is None:
non_tensor_fields = []

selected_device = device if device is not None else self.config.indexing_device
num_docs = len(documents)

@@ -332,12 +342,15 @@
start_time_client_process = timer()
base_path = f"indexes/{self.index_name}/documents"
non_tensor_fields_query_param = utils.convert_list_to_query_params("non_tensor_fields", non_tensor_fields)
image_download_headers_param = (utils.convert_dict_to_url_params(image_download_headers)
if image_download_headers else '')
query_str_params = (
f"{f'&device={utils.translate_device_string_for_url(selected_device)}'}"
f"{f'&processes={processes}' if processes is not None else ''}"
f"{f'&batch_size={server_batch_size}' if server_batch_size is not None else ''}"
f"{f'&use_existing_tensors={str(use_existing_tensors).lower()}' if use_existing_tensors is not None else ''}"
f"{f'&{non_tensor_fields_query_param}' if len(non_tensor_fields) > 0 else ''}"
f"{f'&image_download_headers={image_download_headers_param}' if image_download_headers else ''}"
)
end_time_client_process = timer()
total_client_process_time = end_time_client_process - start_time_client_process
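For illustration, a small sketch of how the helpers above compose into the query string; the field and header values are made up, and only the two relevant parameters are shown:

from marqo import utils

non_tensor_fields = ["caption"]                   # illustrative
image_download_headers = {"x-api-key": "secret"}  # illustrative

non_tensor_fields_query_param = utils.convert_list_to_query_params(
    "non_tensor_fields", non_tensor_fields
)
image_download_headers_param = (
    utils.convert_dict_to_url_params(image_download_headers)
    if image_download_headers else ''
)
query_str_params = (
    f"&{non_tensor_fields_query_param}"
    f"&image_download_headers={image_download_headers_param}"
)
# query_str_params ==
#   '&non_tensor_fields=caption'
#   '&image_download_headers=%7B%22x-api-key%22%3A+%22secret%22%7D'
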
19 changes: 18 additions & 1 deletion src/marqo/utils.py
@@ -1,3 +1,5 @@
import json
import urllib.parse
from marqo import errors
from typing import Any, Optional, List

@@ -37,6 +39,7 @@ def translate_device_string_for_url(device: Optional[str]) -> Optional[str]:
lowered_device = device.lower()
return lowered_device.replace(":", "")


def convert_list_to_query_params(query_param: str, x: List[Any]) -> str:
""" Converts a list value for a query parameter to its query string.

@@ -47,4 +50,18 @@ def convert_list_to_query_params(query_param: str, x: List[Any]) -> str:
Returns:
A rendered query string for the given parameter and parameter value.
"""
return "&".join([f"{query_param}={str(xx)}" for xx in x])
return "&".join([f"{query_param}={str(xx)}" for xx in x])


def convert_dict_to_url_params(d: dict) -> str:
"""Converts a dict into a url-encoded string that can be appended as a query_param
Args:
d: dict to be converted

Returns:
A URL-encoded string
"""
as_str = json.dumps(d)
url_encoded = urllib.parse.quote_plus(as_str)
return url_encoded
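
As a quick sanity check of the encoding, a round-trip sketch; the decode step only stands in for whatever the server does with this parameter and is not part of this PR:

import json
import urllib.parse

from marqo import utils

headers = {"Authorization": "Bearer abc123"}
encoded = utils.convert_dict_to_url_params(headers)
# encoded == '%7B%22Authorization%22%3A+%22Bearer+abc123%22%7D'

# Unquoting and JSON-decoding recovers the original dict.
assert json.loads(urllib.parse.unquote_plus(encoded)) == headers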