Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added model_auth to search, bulk_search, and add_docs #87

Merged
merged 3 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def search(self, q: Union[str, dict], searchable_attributes: Optional[List[str]]
highlights=None, device: Optional[str] = None, filter_string: str = None,
show_highlights=True, reranker=None, image_download_headers: Optional[Dict] = None,
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict[str,List[Union[float, int]]]] = None,
context: Optional[dict] = None, score_modifiers: Optional[dict] = None,
context: Optional[dict] = None, score_modifiers: Optional[dict] = None, model_auth: Optional[dict] = None
) -> Dict[str, Any]:
"""Search the index.

Expand All @@ -145,6 +145,7 @@ def search(self, q: Union[str, dict], searchable_attributes: Optional[List[str]]
retrieved.
context: a dictionary to allow you to bring your own vectors and more into search.
score_modifiers: a dictionary to modify the score based on field values, for tensor search only
model_auth: authorisation that lets Marqo download a private model, if required
Returns:
Dictionary with hits and other metadata
"""
Expand Down Expand Up @@ -180,6 +181,8 @@ def search(self, q: Union[str, dict], searchable_attributes: Optional[List[str]]
body["context"] = context
if score_modifiers is not None:
body["scoreModifiers"] = score_modifiers
if model_auth is not None:
body["modelAuth"] = model_auth
res = self.http.post(
path=path_with_query_str,
body=body
Expand Down Expand Up @@ -246,6 +249,7 @@ def add_documents(
use_existing_tensors: bool = False,
image_download_headers: dict = None,
mappings: dict = None,
model_auth: dict = None
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Add documents to this index. Does a partial update on existing documents,
based on their ID. Adds unseen documents to the index.
Expand All @@ -269,7 +273,7 @@ def add_documents(
image_download_headers: a dictionary of headers to be passed while downloading images,
for URLs found in documents
mappings: a dictionary to help handle the object fields. e.g., multimodal_combination field

model_auth: used to authorise a private model
Returns:
Response body outlining indexing result
"""
Expand All @@ -281,7 +285,8 @@ def add_documents(
update_method="replace",
documents=documents, auto_refresh=auto_refresh, server_batch_size=server_batch_size,
client_batch_size=client_batch_size, processes=processes, device=device, non_tensor_fields=non_tensor_fields,
use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, mappings = mappings
use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, mappings=mappings,
model_auth=model_auth
)

def update_documents(
Expand Down Expand Up @@ -342,7 +347,8 @@ def _generic_add_update_docs(
non_tensor_fields: List = None,
use_existing_tensors: bool = False,
image_download_headers: dict = None,
mappings: dict = None
mappings: dict = None,
model_auth: dict = None
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:

error_detected_message = ('Errors detected in add documents call. '
Expand All @@ -360,6 +366,7 @@ def _generic_add_update_docs(
non_tensor_fields_query_param = utils.convert_list_to_query_params("non_tensor_fields", non_tensor_fields)
image_download_headers_param = (utils.convert_dict_to_url_params(image_download_headers)
if image_download_headers else '')
model_auth_param = (utils.convert_dict_to_url_params(model_auth) if model_auth else '')
mappings_param = (utils.convert_dict_to_url_params(mappings) if mappings else '')
query_str_params = (
f"{f'&device={utils.translate_device_string_for_url(selected_device)}'}"
Expand All @@ -369,6 +376,7 @@ def _generic_add_update_docs(
f"{f'&{non_tensor_fields_query_param}' if len(non_tensor_fields) > 0 else ''}"
f"{f'&image_download_headers={image_download_headers_param}' if image_download_headers else ''}"
f"{f'&mappings={mappings_param}' if mappings else ''}"
f"{f'&model_auth={model_auth_param}' if model_auth_param else ''}"
)
end_time_client_process = timer()
total_client_process_time = end_time_client_process - start_time_client_process
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class SearchBody(BaseMarqoModel):
image_download_headers: Optional[Dict] = None
context: Optional[Dict] = None
scoreModifiers: Optional[Dict] = None

modelAuth: Optional[Dict] = None

class BulkSearchBody(SearchBody):
index: str
Expand Down
4 changes: 2 additions & 2 deletions src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__marqo_version__ = "0.0.18"
__marqo_release_page__ = "https://github.com/marqo-ai/marqo/releases/tag/0.0.18"
__marqo_version__ = "0.0.19"
__marqo_release_page__ = "https://github.com/marqo-ai/marqo/releases/tag/0.0.19"


def supported_marqo_version() -> str:
Expand Down
93 changes: 93 additions & 0 deletions tests/v0_tests/test_model_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import json
from marqo.client import Client
from marqo.errors import MarqoApiError, MarqoError, MarqoWebError
from tests.marqo_test import MarqoTestCase
from marqo.utils import convert_dict_to_url_params
from unittest import mock


class TestAddDocuments(MarqoTestCase):
    """Check that ``model_auth`` is passed on by add_documents, search and bulk_search.

    Each test replaces ``HttpRequests.post`` with a ``MagicMock`` and then
    inspects the keyword arguments the client handed to it.
    """

    # Dotted path of the HTTP method that every test patches out.
    _POST_TARGET = "marqo._httprequests.HttpRequests.post"

    @staticmethod
    def _s3_model_auth():
        """Return a fresh dummy S3 ``model_auth`` dictionary."""
        return {'s3': {'aws_access_key_id': 'some_acc_key',
                       'aws_secret_access_key': 'some_sec_acc_key'}}

    def setUp(self) -> None:
        """Build the client and try to delete the test index."""
        self.client = Client(**self.client_settings)
        self.index_name_1 = "my-test-index-1"
        try:
            self.client.delete_index(self.index_name_1)
        except MarqoApiError:
            # Best-effort cleanup: an API error from the delete is ignored.
            pass

    def test_add_docs_model_auth(self):
        """add_documents places the URL-encoded model_auth in the request path."""
        post_mock = mock.MagicMock()
        with mock.patch(self._POST_TARGET, post_mock):
            model_auth = self._s3_model_auth()
            wanted = f"&model_auth={convert_dict_to_url_params(model_auth)}"
            index = self.client.index(index_name=self.index_name_1)
            index.add_documents(documents=[{"some": "data"}], model_auth=model_auth)
            _, call_kwargs = post_mock.call_args
            assert wanted in call_kwargs['path']

    def test_add_docs_model_client_batching(self):
        """With client batching, each post path has model_auth or 'refresh'."""
        post_mock = mock.MagicMock()
        with mock.patch(self._POST_TARGET, post_mock):
            model_auth = self._s3_model_auth()
            wanted = f"&model_auth={convert_dict_to_url_params(model_auth)}"
            docs = [{"some": f"data {i}"} for i in range(20)]
            self.client.index(index_name=self.index_name_1).add_documents(
                documents=docs, model_auth=model_auth, client_batch_size=10
            )
            for _, call_kwargs in post_mock.call_args_list:
                path = call_kwargs['path']
                assert wanted in path or 'refresh' in path

            # 20 documents with client_batch_size=10 must yield three posts.
            assert len(post_mock.call_args_list) == 3

    def test_search_model_auth(self):
        """search sends model_auth under the 'modelAuth' key of the body."""
        post_mock = mock.MagicMock()
        with mock.patch(self._POST_TARGET, post_mock):
            model_auth = self._s3_model_auth()
            index = self.client.index(index_name=self.index_name_1)
            index.search(q='something', model_auth=model_auth)
            _, call_kwargs = post_mock.call_args
            assert call_kwargs['body']['modelAuth'] == model_auth

    def test_bulk_search_model_auth(self):
        """bulk_search keeps each query's 'modelAuth' in the JSON body."""
        post_mock = mock.MagicMock()
        with mock.patch(self._POST_TARGET, post_mock):
            model_auth = self._s3_model_auth()
            query = {
                "index": self.index_name_1,
                "q": "a",
                "modelAuth": model_auth
            }
            self.client.bulk_search([query])

            _, call_kwargs = post_mock.call_args
            sent_body = json.loads(call_kwargs['body'])
            assert sent_body['queries'][0]['modelAuth'] == model_auth