diff --git a/requirements.txt b/requirements.txt
index 21e62b50f..7015221c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,8 @@ pytest
 tox
 # s2_inference:
 more_itertools
+boto3==1.25.4
+botocore==1.28.4
 nltk==3.7
 torch==1.12.1
 torchvision==0.13.1
diff --git a/src/marqo/README.md b/src/marqo/README.md
index 2243943d8..513d6eefd 100644
--- a/src/marqo/README.md
+++ b/src/marqo/README.md
@@ -237,3 +237,8 @@ curl http://localhost:8882/openapi.json
 ```
 
 To get the human readable spec, visit `http://localhost:8882/docs`
+## IDE tips
+
+### PyCharm
+Pydantic dataclasses are used in this project. By default, PyCharm can't parse initialisations of these dataclasses.
+[This plugin](https://plugins.jetbrains.com/plugin/12861-pydantic) can help.
diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py
index 5099c2121..266a524c0 100644
--- a/src/marqo/s2_inference/clip_utils.py
+++ b/src/marqo/s2_inference/clip_utils.py
@@ -1,7 +1,6 @@
-# from torch import FloatTensor
-# from typing import Any, Dict, List, Optional, Union
 import os
-import PIL.Image
+from marqo.tensor_search.enums import ModelProperties, InferenceParams
+from marqo.tensor_search.models.private_models import ModelLocation, ModelAuth
 import validators
 import requests
 import numpy as np
@@ -15,7 +14,7 @@ from marqo.s2_inference.logger import get_logger
 from marqo.s2_inference.errors import IncompatibleModelDeviceError, InvalidModelPropertiesError
 from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-from marqo.s2_inference.processing.custom_clip_utils import HFTokenizer, download_pretrained_from_url
+from marqo.s2_inference.processing.custom_clip_utils import HFTokenizer, download_model
 from torchvision.transforms import InterpolationMode
 from marqo.s2_inference.configs import ModelCache
@@ -205,11 +204,45 @@ def __init__(self, model_type: str = "ViT-B/32", device: str = 'cpu', embedding
         self.truncate = truncate
         self.model_properties = kwargs.get("model_properties", dict())
 
+        # model_auth gets passed through add_docs and search requests:
+        model_auth = kwargs.get(InferenceParams.model_auth, None)
+        if model_auth is not None:
+            self.model_auth = model_auth
+        else:
+            self.model_auth = None
+
+    def _download_from_repo(self):
+        """Downloads the model from an external repo, like s3, and returns the filepath.
+
+        Returns:
+            The model's filepath
+
+        Raises:
+            RuntimeError if an empty filepath is detected. This is important
+            because OpenCLIP will instantiate a model with random weights if
+            a filepath isn't specified and the model isn't a publicly
+            available HF or OpenAI one.
+        """
+        model_location = ModelLocation(**self.model_properties[ModelProperties.model_location])
+        download_model_params = {"repo_location": model_location}
+
+        if model_location.auth_required:
+            download_model_params['auth'] = self.model_auth
+
+        model_file_path = download_model(**download_model_params)
+        if model_file_path is None or model_file_path == '':
+            raise RuntimeError(
+                'download_model() needs to return a valid filepath to the model! Instead, received '
+                f' filepath `{model_file_path}`')
+        return model_file_path
+
     def load(self) -> None:
+        model_location_presence = ModelProperties.model_location in self.model_properties
+
         path = self.model_properties.get("localpath", None) or self.model_properties.get("url",None)
 
-        if path is None:
+        if path is None and not model_location_presence:
             # The original method to load the openai clip model
             # https://github.com/openai/CLIP/issues/30
             self.model, self.preprocess = clip.load(self.model_type, device='cpu', jit=False, download_root=ModelCache.clip_cache_path)
@@ -217,10 +250,17 @@ def load(self) -> None:
             self.tokenizer = clip.tokenize
         else:
             logger.info("Detecting custom clip model path. We use generic clip model loading.")
-            if os.path.isfile(path):
+            if path and model_location_presence:
+                raise InvalidModelPropertiesError(
+                    "Only one of `url`, `localpath` or `model_location` can be specified in "
+                    "`model_properties`. Please ensure that only one of these is specified in "
+                    "model_properties and retry.")
+            if model_location_presence:
+                self.model_path = self._download_from_repo()
+            elif os.path.isfile(path):
                 self.model_path = path
             elif validators.url(path):
-                self.model_path = download_pretrained_from_url(path)
+                self.model_path = download_model(url=path)
             else:
                 raise InvalidModelPropertiesError(f"Marqo can not load the custom clip model."
                                                   f"The provided model path `{path}` is neither a local file nor a valid url."
                                                   f"Please check your provided model url and retry."
                                                   f"Check `https://docs.marqo.ai/0.0.13/Models-Reference/dense_retrieval/#generic-clip-models` for more info.")
@@ -356,23 +396,33 @@ def load(self) -> None:
         # https://github.com/mlfoundations/open_clip
         path = self.model_properties.get("localpath", None) or self.model_properties.get("url", None)
 
-        if path is None:
+        model_location_presence = ModelProperties.model_location in self.model_properties
+
+        if path is None and not model_location_presence:
             self.model, _, self.preprocess = open_clip.create_model_and_transforms(self.model_name, pretrained=self.pretrained, device=self.device, jit=False, cache_dir=ModelCache.clip_cache_path)
             self.tokenizer = open_clip.get_tokenizer(self.model_name)
             self.model.eval()
         else:
+            if path and model_location_presence:
+                raise InvalidModelPropertiesError(
+                    "Only one of `url`, `localpath` or `model_location` can be specified in "
+                    "`model_properties`. Please ensure that only one of these is specified in "
+                    "model_properties and retry.")
             logger.info("Detecting custom clip model path. We use generic clip model loading.")
-            if os.path.isfile(path):
+            if model_location_presence:
+                self.model_path = self._download_from_repo()
+            elif os.path.isfile(path):
                 self.model_path = path
             elif validators.url(path):
-                self.model_path = download_pretrained_from_url(path)
+                self.model_path = download_model(url=path)
             else:
-                raise InvalidModelPropertiesError(f"Marqo can not load the custom clip model."
-                                                  f"The provided model path `{path}` is neither a local file nor a valid url."
-                                                  f"Please check your provided model url and retry."
-                                                  f"Check `https://docs.marqo.ai/0.0.13/Models-Reference/dense_retrieval/#generic-clip-models` for more info.")
+                raise InvalidModelPropertiesError(
+                    f"Marqo cannot load the custom clip model. "
+                    f"The provided model path `{path}` is neither a local file nor a valid url. "
+                    f"Please check your provided model url and retry. "
+                    f"Check `https://docs.marqo.ai/0.0.13/Models-Reference/dense_retrieval/#generic-clip-models` for more info.")
 
         self.precision = self.model_properties.get("precision", "fp32")
         self.jit = self.model_properties.get("jit", False)
@@ -384,14 +434,13 @@ def load(self) -> None:
 
         self.model.eval()
 
-
     def custom_clip_load(self):
         self.model_name = self.model_properties.get("name", None)
 
-        logger.info(f"The name of the custom clip model is {self.model_name}. We use open_clip load")
-        model, _, preprocess = open_clip.create_model_and_transforms(model_name=self.model_name, jit = self.jit, pretrained=self.model_path, precision = self.precision,
-                                                                     image_mean=self.mean, image_std=self.std, device = self.device, cache_dir=ModelCache.clip_cache_path)
+        model, _, preprocess = open_clip.create_model_and_transforms(
+            model_name=self.model_name, jit = self.jit, pretrained=self.model_path, precision = self.precision,
+            image_mean=self.mean, image_std=self.std, device = self.device, cache_dir=ModelCache.clip_cache_path)
         return model, preprocess
diff --git a/src/marqo/s2_inference/model_downloading/__init__.py b/src/marqo/s2_inference/model_downloading/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/marqo/s2_inference/model_downloading/from_hf.py b/src/marqo/s2_inference/model_downloading/from_hf.py
new file mode 100644
index 000000000..3a201f6b5
--- /dev/null
+++ b/src/marqo/s2_inference/model_downloading/from_hf.py
@@ -0,0 +1,48 @@
+from marqo.tensor_search.models.external_apis.hf import HfAuth, HfModelLocation
+from typing import Optional
+from huggingface_hub import hf_hub_download
+from marqo.s2_inference.logger import get_logger
+from huggingface_hub.utils._errors import RepositoryNotFoundError
+from marqo.s2_inference.errors import ModelDownloadError
+
+logger = get_logger(__name__)
+
+
+def download_model_from_hf(
+        location: HfModelLocation,
+        auth: Optional[HfAuth] = None,
+        download_dir: Optional[str] = None):
+    """Downloads a pretrained model from HF, if it doesn't exist locally. The basename of the
+    location's filename is used as the local filename.
+
+    hf_hub_download downloads the model if it does not yet exist in the cache.
+
+    Args:
+        location: repo_id and filename to be downloaded.
+        auth: contains HF API token for model access
+        download_dir: [not yet implemented]. The location where the model
+            should be stored
+
+    Returns:
+        Path to the downloaded model
+    """
+    if download_dir is not None:
+        logger.warning(
+            "Hugging Face model download was given the `download_dir` argument, "
+            "even though it is not yet implemented. "
+            "The specified model will be downloaded but the `download_dir` "
+            "parameter will be ignored."
+        )
+    download_kwargs = location.dict()
+    if auth is not None:
+        download_kwargs = {**download_kwargs, **auth.dict()}
+    try:
+        return hf_hub_download(**download_kwargs)
+    except RepositoryNotFoundError:
+        # TODO: add link to HF model auth/loc
+        raise ModelDownloadError(
+            "Could not find the specified Hugging Face model repository. Please ensure that the request's model_auth's "
+            "`hf` credentials and the index's model_location are correct. "
+            "If the index's model_location is not correct, please create a new index with the corrected model_location"
+        )
+
diff --git a/src/marqo/s2_inference/model_downloading/from_s3.py b/src/marqo/s2_inference/model_downloading/from_s3.py
new file mode 100644
index 000000000..d59b45bc3
--- /dev/null
+++ b/src/marqo/s2_inference/model_downloading/from_s3.py
@@ -0,0 +1,74 @@
+import os
+from marqo.s2_inference.configs import ModelCache
+from marqo.tensor_search.models.external_apis.s3 import S3Auth, S3Location
+from typing import Optional
+import boto3
+from marqo.s2_inference.errors import ModelDownloadError
+from botocore.exceptions import NoCredentialsError
+
+
+def get_presigned_s3_url(location: S3Location, auth: Optional[S3Auth] = None):
+    """Returns the presigned s3 url of a request to get an S3 object
+
+    Args:
+        location: Bucket and key of model file to be downloaded
+        auth: AWS IAM access keys to a user with access to the model to be downloaded
+
+    Returns:
+        The presigned s3 URL
+
+    TODO: add link to proper usage in error messages
+    """
+    if auth is None:
+        raise ModelDownloadError(
+            "Error retrieving private model. s3 authorisation information is required to "
+            "download a model from an s3 bucket. "
+            "If the model is publicly accessible, please use the model's publicly accessible URL."
+        )
+    s3_client = boto3.client('s3', **auth.dict())
+    try:
+        return s3_client.generate_presigned_url('get_object', Params=location.dict())
+    except NoCredentialsError:
+        raise ModelDownloadError(
+            "Error retrieving private model. AWS credentials were not accepted."
+        )
+
+
+def get_s3_model_absolute_cache_path(location: S3Location) -> str:
+    """Returns the absolute path of an s3 model if it were downloaded.
+
+    Args:
+        location: Bucket and key of model file to be downloaded
+
+    Returns:
+        The absolute path of an s3 model if it were downloaded.
+    """
+    cache_dir = os.path.expanduser(ModelCache.clip_cache_path)
+    return os.path.join(cache_dir, get_s3_model_cache_filename(location))
+
+
+def check_s3_model_already_exists(location: S3Location) -> bool:
+    """Returns True iff an s3 model is already downloaded
+
+    Args:
+        location: Bucket and key of model file to be downloaded
+
+    Returns:
+        True if the model file already exists in the local cache, otherwise False
+    """
+    abs_path = get_s3_model_absolute_cache_path(location)
+    return os.path.isfile(abs_path)
+
+
+def get_s3_model_cache_filename(location: S3Location) -> str:
+    """Returns the model cache filename of an s3 object
+
+    Args:
+        location: Bucket and key of model file to be downloaded
+
+    Returns:
+        The model cache filename of an s3 object
+    """
+    return os.path.basename(location.Key)
+
+
diff --git a/src/marqo/s2_inference/processing/custom_clip_utils.py b/src/marqo/s2_inference/processing/custom_clip_utils.py
index 847f66ef6..28c3deb7b 100644
--- a/src/marqo/s2_inference/processing/custom_clip_utils.py
+++ b/src/marqo/s2_inference/processing/custom_clip_utils.py
@@ -7,6 +7,16 @@
 import urllib
 from tqdm import tqdm
 from marqo.s2_inference.configs import ModelCache
+from typing import Optional
+from urllib.error import HTTPError
+from marqo.s2_inference.errors import ModelDownloadError, InvalidModelPropertiesError
+from marqo.tensor_search.models.private_models import ModelAuth, ModelLocation
+from marqo.s2_inference.model_downloading.from_s3 import (
+    get_presigned_s3_url, get_s3_model_cache_filename, check_s3_model_already_exists,
+    get_s3_model_absolute_cache_path
+)
+from marqo.s2_inference.model_downloading.from_hf import download_model_from_hf
+from marqo.tensor_search.models.external_apis.s3 import S3Auth, S3Location
 
 
 def whitespace_clean(text):
@@ -37,16 +47,98 @@ def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.
         return input_ids
 
 
+def download_model(
+        repo_location: Optional[ModelLocation] = None,
+        url: Optional[str] = None,
+        auth: Optional[ModelAuth] = None,
+        download_dir: Optional[str] = None
+    ):
+    """Downloads a custom CLIP model.
+
+    Args:
+        repo_location: object that contains information about the location of a
+            model. For example, s3 bucket and object path
+        url: location of a model specified by a URL
+        auth: object that contains information about authorisation required to
+            download a model. For example, s3 access keys
+        download_dir: The directory where the model should be downloaded.
+ + Returns: + The path of the downloaded model + """ + single_weight_location_validation_msg = ( + "only exactly one of parameters (repo_location, url) is allowed to be specified.") + if repo_location is None and url is None: + raise InvalidModelPropertiesError(single_weight_location_validation_msg) + if repo_location is not None and url is not None: + raise InvalidModelPropertiesError(single_weight_location_validation_msg) + + if url: + return download_pretrained_from_url(url=url, cache_dir=download_dir) + + if repo_location.s3: + download_kwargs = {'location': repo_location.s3, 'download_dir': download_dir} + if auth is not None: + download_kwargs['auth'] = auth.s3 + return download_pretrained_from_s3(**download_kwargs) + elif repo_location.hf: + download_kwargs = {'location': repo_location.hf, 'download_dir': download_dir} + if auth is not None: + download_kwargs['auth'] = auth.hf + return download_model_from_hf(**download_kwargs) + + + +def download_pretrained_from_s3( + location: S3Location, + auth: Optional[S3Auth] = None, + download_dir: Optional[str] = None): + """Downloads a pretrained model from S3, if it doesn't exist locally. The basename of the object's + key is used for the filename. + + Args: + location: Bucket and key of model file to be downloaded + auth: AWS IAM access keys to a user with access to the model to be downloaded + download_dir: the location where the model should be stored + + Returns: + Path to the downloaded model + """ + + if check_s3_model_already_exists(location=location): + # TODO: check if abs path is even the most appropriate??? + return get_s3_model_absolute_cache_path(location=location) + + url = get_presigned_s3_url(location=location, auth=auth) + + try: + return download_pretrained_from_url( + url=url, cache_dir=download_dir, + cache_file_name=get_s3_model_cache_filename(location) + ) + except HTTPError as e: + if e.code == 403: + # TODO: add link to auth docs + raise ModelDownloadError( + "Received 403 error when trying to retrieve model from s3 storage. " + "Please check the request's s3 credentials and try again. " + ) + + def download_pretrained_from_url( url: str, cache_dir: Union[str, None] = None, + cache_file_name: Optional[str] = None ): ''' - This function takes a clip model checkpoint url as input, downloads the model, and returns the local - path of the downloaded file. + This function takes a clip model checkpoint url as input, downloads the model if it doesn't exist locally, + and returns the local path of the downloaded file. + Args: url: a valid string of the url address. cache_dir: the directory to store the file + cache_file_name: name of the model file when it gets downloaded to the cache. + If not provided, the basename of the URL is used. Returns: download_target: the local path of the downloaded file. ''' @@ -54,7 +146,11 @@ def download_pretrained_from_url( if not cache_dir: cache_dir = os.path.expanduser(ModelCache.clip_cache_path) os.makedirs(cache_dir, exist_ok=True) - filename = os.path.basename(url) + + if cache_file_name is None: + filename = os.path.basename(url) + else: + filename = cache_file_name download_target = os.path.join(cache_dir, filename) diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py index 8d5f7f41d..56a9f6dd1 100644 --- a/src/marqo/s2_inference/s2_inference.py +++ b/src/marqo/s2_inference/s2_inference.py @@ -1,11 +1,11 @@ """This is the interface for interacting with S2 Inference The functions defined here would have endpoints, later on. 
""" -import functools import numpy as np from marqo.errors import ModelCacheManagementError -from marqo.s2_inference.errors import (VectoriseError, InvalidModelPropertiesError, ModelLoadError, - UnknownModelError, ModelNotInCacheError) +from marqo.s2_inference.errors import ( + VectoriseError, InvalidModelPropertiesError, ModelLoadError, + UnknownModelError, ModelNotInCacheError, ModelDownloadError) from PIL import UnidentifiedImageError from marqo.s2_inference.model_registry import load_model_properties from marqo.s2_inference.configs import get_default_device, get_default_normalization, get_default_seq_length @@ -17,6 +17,7 @@ from marqo.tensor_search.utils import read_env_vars_and_defaults from marqo.tensor_search.enums import AvailableModelsKey from marqo.tensor_search.configs import EnvVars +from marqo.tensor_search.models.private_models import ModelAuth import threading from marqo.tensor_search.utils import read_env_vars_and_defaults, generate_batches from marqo.tensor_search.configs import EnvVars @@ -34,7 +35,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties: dict = None, device: str = get_default_device(), normalize_embeddings: bool = get_default_normalization(), - **kwargs) -> List[List[float]]: + model_auth: ModelAuth = None, **kwargs) -> List[List[float]]: """vectorizes the content by model name Args: @@ -45,6 +46,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties: if model_properties['name'] is not in model_registry, these properties are used to fetch the model if model_properties['name'] is in model_registry, default properties are overridden model_properties can be None only if model_name is a model present in the registry + model_auth: Authorisation details for downloading a model (if required) Returns: List[List[float]]: _description_ @@ -56,7 +58,10 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties: validated_model_properties = _validate_model_properties(model_name, model_properties) model_cache_key = _create_model_cache_key(model_name, device, validated_model_properties) - _update_available_models(model_cache_key, model_name, validated_model_properties, device, normalize_embeddings) + _update_available_models( + model_cache_key, model_name, validated_model_properties, device, normalize_embeddings, + model_auth=model_auth + ) try: if isinstance(content, str): @@ -125,8 +130,7 @@ def _create_model_cache_key(model_name: str, device: str, model_properties: dict def _update_available_models(model_cache_key: str, model_name: str, validated_model_properties: dict, - device: str, - normalize_embeddings: bool) -> None: + device: str, normalize_embeddings: bool, model_auth: ModelAuth = None) -> None: """loads the model if it is not already loaded. Note this method assume the model_properties are validated. 
""" @@ -142,17 +146,25 @@ def _update_available_models(model_cache_key: str, model_name: str, validated_mo calling_func=_update_available_models.__name__) try: most_recently_used_time = datetime.datetime.now() - available_models[model_cache_key] = {AvailableModelsKey.model: _load_model(model_name, - validated_model_properties, - device=device, - calling_func = _update_available_models.__name__), - AvailableModelsKey.most_recently_used_time: most_recently_used_time, - AvailableModelsKey.model_size: model_size} + available_models[model_cache_key] = { + AvailableModelsKey.model: _load_model( + model_name, validated_model_properties, + device=device, + calling_func=_update_available_models.__name__, + model_auth=model_auth + ), + AvailableModelsKey.most_recently_used_time: most_recently_used_time, + AvailableModelsKey.model_size: model_size + } logger.info( f'loaded {model_name} on device {device} with normalization={normalize_embeddings} at time={most_recently_used_time}.') except Exception as e: logger.error(f"Error loading model {model_name} on device {device} with normalization={normalize_embeddings}. \n" f"Error message is {str(e)}") + + if isinstance(e, ModelDownloadError): + raise e + raise ModelLoadError( f"Unable to load model={model_name} on device={device} with normalization={normalize_embeddings}. " f"If you are trying to load a custom model, " @@ -300,13 +312,17 @@ def get_model_size(model_name: str, model_properties: dict) -> (int, float): return constants.MODEL_TYPE_SIZE_MAPPING.get(type, constants.DEFAULT_MODEL_SIZE) -def _load_model(model_name: str, model_properties: dict, device: Optional[str] = None, calling_func: str = None) -> Any: +def _load_model( + model_name: str, model_properties: dict, device: Optional[str] = None, + calling_func: str = None, model_auth: ModelAuth = None +) -> Any: """_summary_ Args: model_name (str): Actual model_name to be fetched from external library prefer passing it in the form of model_properties['name'] device (str, optional): _description_. Defaults to 'cpu'. 
+ model_auth: Authorisation details for downloading a model (if required) Returns: Any: _description_ @@ -321,8 +337,10 @@ def _load_model(model_name: str, model_properties: dict, device: Optional[str] = max_sequence_length = model_properties.get('tokens', get_default_seq_length()) - model = loader(model_properties['name'], device=device, embedding_dim=model_properties['dimensions'], - max_seq_length=max_sequence_length, model_properties=model_properties) + model = loader( + model_properties['name'], device=device, embedding_dim=model_properties['dimensions'], + max_seq_length=max_sequence_length, model_properties=model_properties, model_auth=model_auth + ) model.load() diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py index 59687e063..2f61b0533 100644 --- a/src/marqo/tensor_search/api.py +++ b/src/marqo/tensor_search/api.py @@ -2,8 +2,8 @@ import typing from fastapi.responses import JSONResponse from fastapi import Request, Depends -import marqo.tensor_search.delete_docs -import marqo.tensor_search.tensor_search +from marqo.tensor_search.models.add_docs_objects import AddDocsParams +from marqo.tensor_search.models.add_docs_objects import ModelAuth from marqo.errors import InvalidArgError, MarqoWebError, MarqoError from fastapi import FastAPI, Query import json @@ -154,49 +154,60 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api image_download_headers=search_query.image_download_headers, context=search_query.context, score_modifiers=search_query.scoreModifiers, + model_auth=search_query.modelAuth ) @app.post("/indexes/{index_name}/documents") @throttle(RequestType.INDEX) -def add_or_replace_documents(docs: List[Dict], index_name: str, refresh: bool = True, - marqo_config: config.Config = Depends(generate_config), - batch_size: int = 0, processes: int = 1, - non_tensor_fields: List[str] = Query(default=[]), - device: str = Depends(api_validation.validate_device), - use_existing_tensors: bool = False, - image_download_headers: typing.Optional[dict] = Depends( - api_utils.decode_image_download_headers), - mappings: typing.Optional[dict] = Depends( - api_utils.decode_mappings) - ): +def add_or_replace_documents( + docs: List[Dict], + index_name: str, + refresh: bool = True, + marqo_config: config.Config = Depends(generate_config), + batch_size: int = 0, + processes: int = 1, + non_tensor_fields: List[str] = Query(default=[]), + device: str = Depends(api_validation.validate_device), + use_existing_tensors: bool = False, + image_download_headers: typing.Optional[dict] = Depends( + api_utils.decode_image_download_headers + ), + model_auth: typing.Optional[ModelAuth] = Depends( + api_utils.decode_query_string_model_auth + ), + mappings: typing.Optional[dict] = Depends(api_utils.decode_mappings)): """add_documents endpoint (replace existing docs with the same id)""" + add_docs_params = AddDocsParams( + index_name=index_name, docs=docs, auto_refresh=refresh, + device=device, update_mode='replace', non_tensor_fields=non_tensor_fields, + use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, + mappings=mappings, model_auth=model_auth + ) return tensor_search.add_documents_orchestrator( - config=marqo_config, - docs=docs, - index_name=index_name, auto_refresh=refresh, - batch_size=batch_size, processes=processes, device=device, - non_tensor_fields=non_tensor_fields, update_mode='replace', - image_download_headers=image_download_headers, - use_existing_tensors=use_existing_tensors, - mappings=mappings + 
config=marqo_config, add_docs_params=add_docs_params, + batch_size=batch_size, processes=processes ) @app.put("/indexes/{index_name}/documents") @throttle(RequestType.INDEX) -def add_or_update_documents(docs: List[Dict], index_name: str, refresh: bool = True, - marqo_config: config.Config = Depends(generate_config), - batch_size: int = 0, processes: int = 1, - non_tensor_fields: List[str] = Query(default=[]), - device: str = Depends(api_validation.validate_device)): +def add_or_update_documents( + docs: List[Dict], + index_name: str, + refresh: bool = True, + marqo_config: config.Config = Depends(generate_config), + batch_size: int = 0, processes: int = 1, + non_tensor_fields: List[str] = Query(default=[]), + device: str = Depends(api_validation.validate_device)): """WILL BE DEPRECATED SOON. update add_documents endpoint""" + add_docs_params = AddDocsParams( + index_name=index_name, docs=docs, auto_refresh=refresh, + device=device, update_mode='update', non_tensor_fields=non_tensor_fields + ) return tensor_search.add_documents_orchestrator( - config=marqo_config, - docs=docs, - index_name=index_name, auto_refresh=refresh, - batch_size=batch_size, processes=processes, device=device, - non_tensor_fields=non_tensor_fields, update_mode='update' + config=marqo_config, add_docs_params=add_docs_params, + batch_size=batch_size, processes=processes, ) @app.get("/indexes/{index_name}/documents/{document_id}") diff --git a/src/marqo/tensor_search/enums.py b/src/marqo/tensor_search/enums.py index 30ed29ebc..dab6fa37a 100644 --- a/src/marqo/tensor_search/enums.py +++ b/src/marqo/tensor_search/enums.py @@ -131,6 +131,18 @@ class AvailableModelsKey: model_size = "model_size" +class ObjectStores: + s3 = 's3' + hf = 'hf' + + +class ModelProperties: + auth_required = 'auth_required' + model_location = 'model_location' + + +class InferenceParams: + model_auth = "model_auth" # Perhaps create a ThrottleType to differentiate thread_count and data_size throttling mechanisms diff --git a/src/marqo/tensor_search/models/add_docs_objects.py b/src/marqo/tensor_search/models/add_docs_objects.py new file mode 100644 index 000000000..57ccc375b --- /dev/null +++ b/src/marqo/tensor_search/models/add_docs_objects.py @@ -0,0 +1,46 @@ +from pydantic.dataclasses import dataclass +from pydantic import Field +from typing import Optional, Union, Any, Sequence +import numpy as np +from marqo.tensor_search.models.private_models import ModelAuth +from typing import List + + +class AddDocsParamsConfig: + arbitrary_types_allowed = True + + +@dataclass(frozen=True, config=AddDocsParamsConfig) +class AddDocsParams: + """Represents the parameters of the tensor_search.add_documents() function + + Params: + index_name: name of the index + docs: List of documents + auto_refresh: Set to False if indexing lots of docs + non_tensor_fields: List of fields, within documents to not create tensors for. Default to + make tensors for all fields. + use_existing_tensors: Whether to use the vectors already in doc (for update docs) + device: Device used to carry out the document update. + update_mode: {'replace' | 'update'}. 
If set to replace (default) just + image_download_thread_count: number of threads used to concurrently download images + image_download_headers: headers to authenticate image download + mappings: a dictionary used to handle all the object field content in the doc, + e.g., multimodal_combination field + model_auth: an object used to authorise downloading an object from a datastore + + """ + + # this should only accept Sequences of dicts, but currently validation lies elsewhere + docs: Union[Sequence[Union[dict, Any]], np.ndarray] + + index_name: str + auto_refresh: bool + non_tensor_fields: List = Field(default_factory=list) + device: Optional[str] = None + update_mode: str = "replace" + image_download_thread_count: int = 20 + image_download_headers: dict = Field(default_factory=dict) + use_existing_tensors: bool = False + mappings: Optional[dict] = None + model_auth: Optional[ModelAuth] = None diff --git a/src/marqo/tensor_search/models/api_models.py b/src/marqo/tensor_search/models/api_models.py index bc239cee5..aa60cb298 100644 --- a/src/marqo/tensor_search/models/api_models.py +++ b/src/marqo/tensor_search/models/api_models.py @@ -7,6 +7,7 @@ from pydantic import BaseModel from typing import Union, List, Dict, Optional, Any from marqo.tensor_search.enums import SearchMethod, Device +from marqo.tensor_search.models.private_models import ModelAuth from marqo.tensor_search import validation class BaseMarqoModel(BaseModel): @@ -28,6 +29,7 @@ class SearchQuery(BaseMarqoModel): image_download_headers: Optional[Dict] = None context: Optional[Dict] = None scoreModifiers: Optional[Dict] = None + modelAuth: Optional[ModelAuth] = None @pydantic.validator('searchMethod') def validate_search_method(cls, value): diff --git a/src/marqo/tensor_search/models/external_apis/abstract_classes.py b/src/marqo/tensor_search/models/external_apis/abstract_classes.py new file mode 100644 index 000000000..0e4217942 --- /dev/null +++ b/src/marqo/tensor_search/models/external_apis/abstract_classes.py @@ -0,0 +1,18 @@ +""" +These are abstract classes that shouldn't be instantiated +""" +from pydantic import BaseModel + +class ExternalAuth(BaseModel): + """Authentication used to download an object + """ + class Config: + allow_mutation = False + +class ObjectLocation(BaseModel): + """Reference to an object location (for example a pointer to a model file + in s3 + """ + class Config: + allow_mutation = False + diff --git a/src/marqo/tensor_search/models/external_apis/hf.py b/src/marqo/tensor_search/models/external_apis/hf.py new file mode 100644 index 000000000..f6ee958b0 --- /dev/null +++ b/src/marqo/tensor_search/models/external_apis/hf.py @@ -0,0 +1,16 @@ +from pydantic.dataclasses import dataclass +from marqo.tensor_search.models.external_apis.abstract_classes import ( + ObjectLocation, ExternalAuth +) + + +class HfAuth(ExternalAuth): + token: str + + +class HfModelLocation(ObjectLocation): + repo_id: str + filename: str + + + diff --git a/src/marqo/tensor_search/models/external_apis/s3.py b/src/marqo/tensor_search/models/external_apis/s3.py new file mode 100644 index 000000000..cece21e3d --- /dev/null +++ b/src/marqo/tensor_search/models/external_apis/s3.py @@ -0,0 +1,17 @@ +from pydantic.dataclasses import dataclass +from typing import Optional +from marqo.tensor_search.models.external_apis.abstract_classes import ( + ObjectLocation, ExternalAuth +) + + +class S3Auth(ExternalAuth): + aws_secret_access_key: str + aws_access_key_id: str + aws_session_token: Optional[str] = None + + +class 
S3Location(ObjectLocation):
+    Bucket: str
+    Key: str
+
diff --git a/src/marqo/tensor_search/models/private_models.py b/src/marqo/tensor_search/models/private_models.py
new file mode 100644
index 000000000..8bc954d34
--- /dev/null
+++ b/src/marqo/tensor_search/models/private_models.py
@@ -0,0 +1,60 @@
+"""This regards structuring information regarding customer-stored ML models
+
+For example, models stored on custom Huggingface repos or on private s3 buckets
+"""
+from marqo.tensor_search.models.external_apis.hf import HfAuth, HfModelLocation
+from marqo.tensor_search.models.external_apis.s3 import S3Auth, S3Location
+from pydantic import BaseModel, validator
+from marqo.errors import InvalidArgError
+from typing import Optional
+
+class ModelAuth(BaseModel):
+    """TODO: insert links to docs in error message"""
+    class Config:
+        allow_mutation = False
+
+    s3: Optional[S3Auth] = None
+    hf: Optional[HfAuth] = None
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        if self.s3 is None and self.hf is None:
+            raise InvalidArgError(
+                "Missing authentication object. An authentication object, for example `s3` or "
+                "`hf`, must be provided. ")
+
+    @validator('s3', 'hf', pre=True, always=True)
+    def _ensure_exactly_one_auth_method(cls, v, values, field):
+        other_field = 's3' if field.name == 'hf' else 'hf'
+        if other_field in values and values[other_field] is not None and v is not None:
+            raise InvalidArgError(
+                "More than one model authentication object was provided. "
+                "Only one model authentication object is allowed")
+        return v
+
+
+class ModelLocation(BaseModel):
+
+    class Config:
+        allow_mutation = False
+
+    s3: Optional[S3Location] = None
+    hf: Optional[HfModelLocation] = None
+    auth_required: bool = False
+
+    @validator('s3', 'hf', pre=True, always=True)
+    def _ensure_exactly_one_location(cls, v, values, field):
+        """TODO: insert links to docs in error message"""
+        other_field = 's3' if field.name == 'hf' else 'hf'
+        if other_field in values and values[other_field] is not None and v is not None:
+            raise InvalidArgError(
+                "More than one model location object was provided. "
+                "Only one model location object is allowed")
+        return v
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        if self.s3 is None and self.hf is None:
+            raise InvalidArgError(
+                "Missing model location object. A location object, for example `s3` or "
+                "`hf`, must be provided.
") \ No newline at end of file diff --git a/src/marqo/tensor_search/models/search.py b/src/marqo/tensor_search/models/search.py index 3f34680ab..b5ecc47ee 100644 --- a/src/marqo/tensor_search/models/search.py +++ b/src/marqo/tensor_search/models/search.py @@ -1,5 +1,6 @@ import json from pydantic import BaseModel +from marqo.tensor_search.models.private_models import ModelAuth from typing import Any, Union, List, Dict, Optional, NewType, Literal Qidx = NewType('Qidx', int) # Indicates the position of a search query in a bulk search request @@ -24,6 +25,7 @@ class VectorisedJobs(BaseModel): normalize_embeddings: bool image_download_headers: Optional[Dict] content_type: Literal['text', 'image'] + model_auth: Optional[ModelAuth] def __hash__(self): return self.groupby_key() + hash(json.dumps(self.content, sort_keys=True)) diff --git a/src/marqo/tensor_search/models/settings_object.py b/src/marqo/tensor_search/models/settings_object.py index aa15a82b8..12d40a365 100644 --- a/src/marqo/tensor_search/models/settings_object.py +++ b/src/marqo/tensor_search/models/settings_object.py @@ -1,5 +1,5 @@ from marqo.tensor_search import enums as ns_enums -from marqo.tensor_search.enums import IndexSettingsField as NsFields, EnvVars +from marqo.tensor_search.enums import IndexSettingsField as NsFields, EnvVars, ObjectStores from marqo.tensor_search.utils import read_env_vars_and_defaults, read_env_vars_and_defaults_ints settings_schema = { diff --git a/src/marqo/tensor_search/parallel.py b/src/marqo/tensor_search/parallel.py index bee13f002..4ef815e79 100644 --- a/src/marqo/tensor_search/parallel.py +++ b/src/marqo/tensor_search/parallel.py @@ -9,6 +9,10 @@ from marqo import errors from marqo.tensor_search import tensor_search from marqo.marqo_logging import logger +from marqo.tensor_search.models.add_docs_objects import AddDocsParams +from dataclasses import replace +from marqo.config import Config + try: mp.set_start_method('spawn', force=True) @@ -85,30 +89,22 @@ class IndexChunk: """wrapper to pass through documents to be indexed to multiprocessing """ - def __init__(self, config=None, index_name: str = None, docs: List[Dict] = [], - auto_refresh: bool = False, batch_size: int = 50, - device: str = None, process_id: int = 0, - non_tensor_fields: List[str] = [], - threads_per_process: int = None, update_mode: str = 'replace', - use_existing_tensors: bool = False, image_download_headers: Optional[Dict] = None, - mappings: dict = None): + def __init__( + self, + add_docs_params: AddDocsParams, + config: Config, + batch_size: int = 50, + process_id: int = 0, + threads_per_process: int = None): self.config = copy.deepcopy(config) - self.index_name = index_name - self.docs = docs - self.auto_refresh = auto_refresh + self.add_docs_params = add_docs_params self.n_batch = batch_size - self.n_docs = len(docs) + self.n_docs = len(add_docs_params.docs) self.n_chunks = max(1, self.n_docs // self.n_batch) - self.device = device self.process_id = process_id - self.update_mode = update_mode - self.config.indexing_device = device if device is not None else self.config.indexing_device + self.config.indexing_device = add_docs_params.device if add_docs_params.device is not None else self.config.indexing_device self.threads_per_process = threads_per_process - self.non_tensor_fields = non_tensor_fields - self.use_existing_tensors = use_existing_tensors - self.image_download_headers = image_download_headers - self.mappings = mappings def process(self): @@ -117,7 +113,7 @@ def process(self): logger.info(f'starting add 
documents using {self.n_chunks} chunks per process...') - if self.device.startswith('cpu') and self.threads_per_process is not None: + if self.add_docs_params.device.startswith('cpu') and self.threads_per_process is not None: logger.info(f"restricting threads to {self.threads_per_process} for process={self.process_id}") torch.set_num_threads(self.threads_per_process) @@ -127,18 +123,16 @@ def process(self): total_progress_displays = 10 progress_display_frequency = max(1, self.n_chunks // total_progress_displays) - for n_processed,_doc in enumerate(np.array_split(self.docs, self.n_chunks)): + for n_processed,_doc in enumerate(np.array_split(self.add_docs_params.docs, self.n_chunks)): t_chunk_start = time.time() percent_done = self._calculate_percent_done(n_processed + 1, self.n_chunks) if n_processed % progress_display_frequency == 0: - logger.info(f'process={self.process_id} completed={percent_done}/100% on device={self.device}') + logger.info( + f'process={self.process_id} completed={percent_done}/100% on device={self.add_docs_params.device}') results.append(tensor_search.add_documents( - config=self.config, index_name=self.index_name, docs=_doc, auto_refresh=self.auto_refresh, - update_mode=self.update_mode, non_tensor_fields=self.non_tensor_fields, - use_existing_tensors=self.use_existing_tensors, image_download_headers=self.image_download_headers, - mappings=self.mappings + config=self.config, add_docs_params=self.add_docs_params )) t_chunk_end = time.time() @@ -170,22 +164,18 @@ def get_threads_per_process(processes: int): total_cpu = max(1, mp.cpu_count() - 2) return max(1, total_cpu//processes) -def add_documents_mp(config=None, index_name=None, docs=None, - auto_refresh=None, batch_size=50, processes=1, device=None, - non_tensor_fields: List[str] = [], update_mode: str = None, - image_download_headers: Optional[Dict] = None, use_existing_tensors=None, - mappings: dict = None - ): +def add_documents_mp( + add_docs_params: AddDocsParams, + config: Config, + batch_size=50, + processes=1 + ): """add documents using parallel processing using ray Args: - documents (_type_): _description_ - config (_type_, optional): _description_. Defaults to None. - index_name (_type_, optional): _description_. Defaults to None. - auto_refresh (_type_, optional): _description_. Defaults to None. - non_tensor_fields (_type, List[str]): _description_. Fields within documents not to create - tensors for. Defaults to create tensors for all fields. - update_mode (str, optional): - use_existing_tensors + add_docs_params: parameters used by the add_docs call + config: Marqo configuration object + batch_size: size of batch to be processed and sent to Marqo-os + processes: number of processes to use Assumes running on the same host right now. Ray or something else should be used if the processing is distributed. 
@@ -193,11 +183,10 @@ def add_documents_mp(config=None, index_name=None, docs=None, Returns: _type_: _description_ """ - if image_download_headers is None: - image_download_headers = dict() - selected_device = device if device is not None else config.indexing_device - n_documents = len(docs) + selected_device = add_docs_params.device if add_docs_params.device is not None else config.indexing_device + + n_documents = len(add_docs_params.docs) logger.info(f"found {n_documents} documents") @@ -214,12 +203,12 @@ def add_documents_mp(config=None, index_name=None, docs=None, start = time.time() - chunkers = [IndexChunk( - config=config, index_name=index_name, docs=_docs, non_tensor_fields=non_tensor_fields, - auto_refresh=auto_refresh, batch_size=batch_size, update_mode=update_mode, - use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, - process_id=p_id, device=device_ids[p_id], threads_per_process=threads_per_process, mappings=mappings) - for p_id,_docs in enumerate(np.array_split(docs, n_processes))] + chunkers = [ + IndexChunk( + config=config, batch_size=batch_size, + process_id=p_id, threads_per_process=threads_per_process, + add_docs_params=replace(add_docs_params, docs=_docs, device=device_ids[p_id])) + for p_id,_docs in enumerate(np.array_split(add_docs_params.docs, n_processes))] logger.info(f'Performing parallel now across devices {device_ids}...') with mp.Pool(n_processes) as pool: results = pool.map(_run_chunker, chunkers) diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py index dc2c0ff0c..5dc1bad9b 100644 --- a/src/marqo/tensor_search/tensor_search.py +++ b/src/marqo/tensor_search/tensor_search.py @@ -1,13 +1,13 @@ -"""tensor search logic. In the future this will be accessible to the client via an API +"""tensor search logic API Notes: - - Fields beginning with a double underscore "__" are protected and used for our internal purposes. + - Some fields beginning with a double underscore "__" are protected and used for our internal purposes. - Examples include: - __embedding_vector __field_name __field_content __doc_chunk_relation __chunk_ids + fields beginning with "__vector_" - The "_id" field isn't a real field. It's a way to declare an ID. Internally we use it as the ID for the doc. 
The doc is stored without this field in its body @@ -37,8 +37,10 @@ import functools import pprint import typing +from marqo.tensor_search.models.private_models import ModelAuth import uuid from typing import List, Optional, Union, Iterable, Sequence, Dict, Any, Tuple +from marqo.tensor_search.models.add_docs_objects import AddDocsParams import numpy as np from PIL import Image import marqo.config as config @@ -55,6 +57,7 @@ from marqo.tensor_search.models.api_models import BulkSearchQuery, BulkSearchQueryEntity from marqo.tensor_search.models.search import VectorisedJobs, VectorisedJobPointer, Qidx, JHash from marqo.tensor_search.models.index_info import IndexInfo +from marqo.tensor_search.models.external_apis.abstract_classes import ExternalAuth from marqo.tensor_search.utils import add_timing from marqo.tensor_search import delete_docs from marqo.s2_inference.processing import text as text_processor @@ -71,7 +74,7 @@ from marqo import errors from marqo.s2_inference import errors as s2_inference_errors import threading - +from dataclasses import replace from marqo.tensor_search.tensor_search_logging import get_logger logger = get_logger(__name__) @@ -195,26 +198,12 @@ def _autofill_index_settings(index_settings: dict): copied_settings = utils.merge_dicts(default_settings, copied_settings) - # if NsField.index_defaults not in copied_settings: - # copied_settings[NsField.index_defaults] = default_settings[NsField.index_defaults] - if NsField.treat_urls_and_pointers_as_images in copied_settings[NsField.index_defaults] and \ copied_settings[NsField.index_defaults][NsField.treat_urls_and_pointers_as_images] is True \ and copied_settings[NsField.index_defaults][NsField.model] is None: copied_settings[NsField.index_defaults][NsField.model] = MlModel.clip - # make sure the first level of keys are present, if not add all of those defaults - # for key in list(default_settings): - # if key not in copied_settings or copied_settings[key] is None: - # copied_settings[key] = default_settings[key] - - # # make sure the first level of keys in index defaults is present, if not add all of those defaults - # for key in list(default_settings[NsField.index_defaults]): - # if key not in copied_settings[NsField.index_defaults] or \ - # copied_settings[NsField.index_defaults][key] is None: - # copied_settings[NsField.index_defaults][key] = default_settings[NsField.index_defaults][key] - - # text preprocessing sub fields - fills any missing sub-dict fields if some of the first level are present + # text preprocessing subfields - fills any missing sub-dict fields if some of the first level are present for key in list(default_settings[NsField.index_defaults][NsField.text_preprocessing]): if key not in copied_settings[NsField.index_defaults][NsField.text_preprocessing] or \ copied_settings[NsField.index_defaults][NsField.text_preprocessing][key] is None: @@ -247,73 +236,57 @@ def _check_and_create_index_if_not_exist(config: Config, index_name: str): def add_documents_orchestrator( - config: Config, index_name: str, docs: List[dict], - auto_refresh: bool, batch_size: int = 0, processes: int = 1, - non_tensor_fields=None, image_download_headers: dict = None, - device=None, update_mode: str = 'replace', use_existing_tensors: bool = False, - mappings: dict = None + config: Config, add_docs_params: AddDocsParams, + batch_size: int = 0, processes: int = 1, ): - if image_download_headers is None: - image_download_headers = dict() - - if non_tensor_fields is None: - non_tensor_fields = [] if batch_size is None or batch_size 
== 0: logger.debug(f"batch_size={batch_size} and processes={processes} - not doing any marqo side batching") - return add_documents( - config=config, index_name=index_name, docs=docs, auto_refresh=auto_refresh, - device=device, update_mode=update_mode, non_tensor_fields=non_tensor_fields, - use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, - mappings=mappings - ) + return add_documents(config=config, add_docs_params=add_docs_params) elif processes is not None and processes > 1: - # create beforehand or pull from the cache so it is upto date for the multi-processing - _check_and_create_index_if_not_exist(config=config, index_name=index_name) + # create beforehand or pull from the cache so it is up to date for the multi-processing + _check_and_create_index_if_not_exist(config=config, index_name=add_docs_params.index_name) + + try: + _vector_text_search( + config=config, index_name=add_docs_params.index_name, query='', + model_auth=add_docs_params.model_auth, + image_download_headers=add_docs_params.image_download_headers) + except Exception as e: + logger.warning( + f"add_documents orchestrator's call to _vector_text_search, prior to parallel add_docs, raised an error. " + f"Continuing to parallel add_docs. " + f"Message: {e}" + ) logger.debug(f"batch_size={batch_size} and processes={processes} - using multi-processing") results = parallel.add_documents_mp( - config=config, index_name=index_name, docs=docs, - auto_refresh=auto_refresh, batch_size=batch_size, processes=processes, - device=device, update_mode=update_mode, non_tensor_fields=non_tensor_fields, - use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, - mappings=mappings + config=config, batch_size=batch_size, processes=processes, add_docs_params=add_docs_params ) - # we need to force the cache to update as it does not propagate using mp # we just clear this index's entry and it will re-populate when needed next - if index_name in get_cache(): - logger.info(f'deleting cache entry for {index_name} after parallel add documents') - del get_cache()[index_name] + if add_docs_params.index_name in get_cache(): + logger.info(f'deleting cache entry for {add_docs_params.index_name} after parallel add documents') + del get_cache()[add_docs_params.index_name] return results else: if batch_size < 0: raise errors.InvalidArgError("Batch size can't be less than 1!") logger.debug(f"batch_size={batch_size} and processes={processes} - batching using a single process") - return _batch_request(config=config, index_name=index_name, dataset=docs, device=device, - batch_size=batch_size, verbose=False, non_tensor_fields=non_tensor_fields, - use_existing_tensors=use_existing_tensors, - image_download_headers=image_download_headers, mappings=mappings) - - -def _batch_request(config: Config, index_name: str, dataset: List[dict], - batch_size: int = 100, verbose: bool = True, device=None, - update_mode: str = 'replace', non_tensor_fields=None, - image_download_headers: Optional[Dict] = None, use_existing_tensors: bool = False, - mappings: dict = None - ) -> List[Dict[str, Any]]: - """Batch by the number of documents""" - if image_download_headers is None: - image_download_headers = dict() + return _batch_request(config=config, verbose=False, add_docs_params=add_docs_params, batch_size=batch_size) + - if non_tensor_fields is None: - non_tensor_fields = [] +def _batch_request( + config: Config, add_docs_params: AddDocsParams, + verbose: bool = True, batch_size: int = 100 + ) -> 
List[Dict[str, Any]]: + """Batch by the number of documents""" logger.info(f"starting batch ingestion in sizes of {batch_size}") - deeper = ((doc, i, batch_size) for i, doc in enumerate(dataset)) + deeper = ((doc, i, batch_size) for i, doc in enumerate(add_docs_params.docs)) def batch_requests(gathered, doc_tuple): doc, i, the_batch_size = doc_tuple @@ -329,18 +302,14 @@ def verbosely_add_docs(i, docs): t0 = timer() logger.debug(f" batch {i}: beginning ingestion. ") - res = add_documents( - config=config, index_name=index_name, - docs=docs, auto_refresh=False, device=device, - update_mode=update_mode, non_tensor_fields=non_tensor_fields, - use_existing_tensors=use_existing_tensors, image_download_headers=image_download_headers, - mappings=mappings - ) + # we need to just use docs in the batch + batch_add_docs_params = replace(add_docs_params, docs=docs) + res = add_documents(config=config, add_docs_params=batch_add_docs_params) total_batch_time = timer() - t0 num_docs = len(docs) logger.debug(f" batch {i}: ingested {num_docs} docs. Time taken: {(total_batch_time):.3f}. " - f"Average time per doc {(total_batch_time / num_docs):.3f}") + f"Average time per doc {(total_batch_time / num_docs):.3f}") if verbose: logger.debug(f" results from indexing batch {i}: {res}") return res @@ -378,61 +347,44 @@ def _get_chunks_for_field(field_name: str, doc_id: str, doc): return [chunk for chunk in doc["_source"]["__chunks"] if chunk["__field_name"] == field_name] -def add_documents(config: Config, index_name: str, docs: List[dict], auto_refresh: bool, - non_tensor_fields=None, device=None, update_mode: str = "replace", - image_download_thread_count: int = 20, image_download_headers: dict = None, - use_existing_tensors: bool = False, mappings: dict = None): +def add_documents(config: Config, add_docs_params: AddDocsParams): """ Args: config: Config object - index_name: name of the index - docs: List of documents - auto_refresh: Set to False if indexing lots of docs - non_tensor_fields: List of fields, within documents to not create tensors for. Default to - make tensors for all fields. - use_existing_tensors: Whether or not to use the vectors already in doc (for update docs) - device: Device used to carry out the document update. - update_mode: {'replace' | 'update'}. 
If set to replace (default) just - image_download_thread_count: number of threads used to concurrently download images - image_download_headers: headers to authenticate image download - mappings: a dictionary used to handle all the object field content in the doc, e.g., multimodal_combination field + add_docs_params: add_documents()'s parameters Returns: """ # ADD DOCS TIMER-LOGGER (3) - if image_download_headers is None: - image_download_headers = dict() - start_time_3 = timer() - if non_tensor_fields is None: - non_tensor_fields = [] + start_time_3 = timer() - if mappings is not None: - validation.validate_mappings_object(mappings_object=mappings) + if add_docs_params.mappings is not None: + validation.validate_mappings_object(mappings_object=add_docs_params.mappings) t0 = timer() bulk_parent_dicts = [] try: - index_info = backend.get_index_info(config=config, index_name=index_name) + index_info = backend.get_index_info(config=config, index_name=add_docs_params.index_name) except errors.IndexNotFoundError as s: - create_vector_index(config=config, index_name=index_name) - index_info = backend.get_index_info(config=config, index_name=index_name) + create_vector_index(config=config, index_name=add_docs_params.index_name) + index_info = backend.get_index_info(config=config, index_name=add_docs_params.index_name) - if len(docs) == 0: + if len(add_docs_params.docs) == 0: raise errors.BadRequestError(message="Received empty add documents request") - if use_existing_tensors and update_mode != "replace": + if add_docs_params.use_existing_tensors and add_docs_params.update_mode != "replace": raise errors.InvalidArgError("use_existing_tensors=True is only available for add and replace documents," "not for add and update!") valid_update_modes = ('update', 'replace') - if update_mode not in valid_update_modes: - raise errors.InvalidArgError(message=f"Unknown update_mode `{update_mode}` " + if add_docs_params.update_mode not in valid_update_modes: + raise errors.InvalidArgError(message=f"Unknown update_mode `{add_docs_params.update_mode}` " f"received! 
Valid update modes: {valid_update_modes}") - if mappings is not None: - validate_mappings = validation.validate_mappings(mappings) + if add_docs_params.mappings is not None: + validate_mappings = validation.validate_mappings(add_docs_params.mappings) existing_fields = set(index_info.properties.keys()) new_fields = set() @@ -443,32 +395,35 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres # Check backend to see the differences between multimodal_fields and new_fields new_obj_fields = dict() - selected_device = config.indexing_device if device is None else device + selected_device = config.indexing_device if add_docs_params.device is None else add_docs_params.device unsuccessful_docs = [] total_vectorise_time = 0 - batch_size = len(docs) + batch_size = len(add_docs_params.docs) image_repo = {} if index_info.index_settings[NsField.index_defaults][NsField.treat_urls_and_pointers_as_images]: ti_0 = timer() - image_repo = add_docs.download_images(docs=docs, thread_count=20, non_tensor_fields=tuple(non_tensor_fields), - image_download_headers=image_download_headers) + image_repo = add_docs.download_images(docs=add_docs_params.docs, thread_count=20, + non_tensor_fields=tuple(add_docs_params.non_tensor_fields), + image_download_headers=add_docs_params.image_download_headers) logger.debug(f" add_documents image download: took {(timer() - ti_0):.3f}s to concurrently download " - f"images for {batch_size} docs using {image_download_thread_count} threads ") + f"images for {batch_size} docs using {add_docs_params.image_download_thread_count} threads ") - if update_mode == 'replace' and use_existing_tensors: + if add_docs_params.update_mode == 'replace' and add_docs_params.use_existing_tensors: doc_ids = [] # Iterate through the list in reverse, only latest doc with dupe id gets added. 
- for i in range(len(docs)-1, -1, -1): - if ("_id" in docs[i]) and (docs[i]["_id"] not in doc_ids): - doc_ids.append(docs[i]["_id"]) - existing_docs = _get_documents_for_upsert(config=config, index_name=index_name, document_ids=doc_ids) + for i in range(len(add_docs_params.docs)-1, -1, -1): + if ("_id" in add_docs_params.docs[i]) and (add_docs_params.docs[i]["_id"] not in doc_ids): + doc_ids.append(add_docs_params.docs[i]["_id"]) + existing_docs = _get_documents_for_upsert( + config=config, index_name=add_docs_params.index_name, document_ids=doc_ids) - for i, doc in enumerate(docs): + for i, doc in enumerate(add_docs_params.docs): - indexing_instructions = {'index' if update_mode == 'replace' else 'update': {"_index": index_name}} + indexing_instructions = { + 'index' if add_docs_params.update_mode == 'replace' else 'update': {"_index": add_docs_params.index_name}} copied = copy.deepcopy(doc) document_is_valid = True @@ -492,9 +447,9 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres ) continue - if update_mode == "replace": + if add_docs_params.update_mode == "replace": indexing_instructions["index"]["_id"] = doc_id - if use_existing_tensors: + if add_docs_params.use_existing_tensors: matching_doc = [doc for doc in existing_docs["docs"] if doc["_id"] == doc_id] # Should only have 1 result, as only 1 id matches if len(matching_doc) == 1: @@ -525,10 +480,12 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres try: field_content = validation.validate_field_content( - field_content=copied[field], is_non_tensor_field=field in non_tensor_fields) + field_content=copied[field], is_non_tensor_field=field in add_docs_params.non_tensor_fields) if isinstance(field_content, dict): - field_content = validation.validate_dict(field = field, field_content = field_content, - is_non_tensor_field=field in non_tensor_fields, mappings=mappings) + field_content = validation.validate_dict( + field=field, field_content=field_content, + is_non_tensor_field=field in add_docs_params.non_tensor_fields, + mappings=add_docs_params.mappings) except errors.InvalidArgError as err: document_is_valid = False unsuccessful_docs.append( @@ -541,13 +498,14 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres new_fields_from_doc.add((field, _infer_opensearch_data_type(copied[field]))) # Don't process text/image fields when explicitly told not to. - if field in non_tensor_fields: + if field in add_docs_params.non_tensor_fields: continue # chunks generated by processing this field for this doc: chunks_to_append = [] # Check if content of this field changed. 
If no, skip all chunking and vectorisation - if ((update_mode == 'replace') and use_existing_tensors and existing_doc["found"] + if ((add_docs_params.update_mode == 'replace') and add_docs_params.use_existing_tensors + and existing_doc["found"] and (field in existing_doc["_source"]) and (existing_doc["_source"][field] == field_content)): chunks_to_append = _get_chunks_for_field(field_name=field, doc_id=doc_id, doc=existing_doc) @@ -625,13 +583,14 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=content_chunks, device=selected_device, normalize_embeddings=normalize_embeddings, - infer=infer_if_image) + infer=infer_if_image, model_auth=add_docs_params.model_auth) end_time = timer() total_vectorise_time += (end_time - start_time) except (s2_inference_errors.UnknownModelError, s2_inference_errors.InvalidModelPropertiesError, - s2_inference_errors.ModelLoadError) as model_error: + s2_inference_errors.ModelLoadError, + s2_inference.ModelDownloadError) as model_error: raise errors.BadRequestError( message=f'Problem vectorising query. Reason: {str(model_error)}', link="https://marqo.pages.dev/latest/Models-Reference/dense_retrieval/" @@ -660,12 +619,12 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres }) elif isinstance(field_content, dict): - if mappings[field]["type"]=="multimodal_combination": - combo_chunk, combo_document_is_valid, unsuccessful_doc_to_append, combo_vectorise_time_to_add,\ - new_fields_from_multimodal_combination= \ - vectorise_multimodal_combination_field(field, field_content, copied, - i, doc_id, selected_device, index_info, image_repo, mappings[field]) - + if add_docs_params.mappings[field]["type"] == "multimodal_combination": + (combo_chunk, combo_document_is_valid, + unsuccessful_doc_to_append, combo_vectorise_time_to_add, + new_fields_from_multimodal_combination) = vectorise_multimodal_combination_field( + field, field_content, copied, i, doc_id, selected_device, index_info, + image_repo, add_docs_params.mappings[field], model_auth=add_docs_params.model_auth) total_vectorise_time = total_vectorise_time + combo_vectorise_time_to_add if combo_document_is_valid is False: unsuccessful_docs.append(unsuccessful_doc_to_append) @@ -684,7 +643,7 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres if document_is_valid: new_fields = new_fields.union(new_fields_from_doc) - if update_mode == 'replace': + if add_docs_params.update_mode == 'replace': copied[TensorField.chunks] = chunks bulk_parent_dicts.append(indexing_instructions) bulk_parent_dicts.append(copied) @@ -742,7 +701,7 @@ def merged_doc = [:]; "doc_fields": list(copied.keys()), "new_chunks": chunks, "customer_dict": copied, - "non_tensor_fields": non_tensor_fields + "non_tensor_fields": add_docs_params.non_tensor_fields }, } }) @@ -758,7 +717,7 @@ def merged_doc = [:]; if bulk_parent_dicts: # the HttpRequest wrapper handles error logic update_mapping_response = backend.add_customer_field_properties( - config=config, index_name=index_name, customer_field_names=new_fields, + config=config, index_name=add_docs_params.index_name, customer_field_names=new_fields, model_properties=_get_model_properties(index_info), multimodal_combination_fields=new_obj_fields) # ADD DOCS TIMER-LOGGER (5) @@ -778,8 +737,8 @@ def merged_doc = [:]; else: index_parent_response = None - if auto_refresh: - refresh_response = 
HttpRequests(config).post(path=F"{index_name}/_refresh") + if add_docs_params.auto_refresh: + refresh_response = HttpRequests(config).post(path=F"{add_docs_params.index_name}/_refresh") t1 = timer() @@ -793,7 +752,7 @@ def translate_add_doc_response(response: Optional[dict], time_diff: float) -> di copied_res = copy.deepcopy(response) result_dict['errors'] = copied_res['errors'] - actioned = "index" if update_mode == 'replace' else 'update' + actioned = "index" if add_docs_params.update_mode == 'replace' else 'update' for item in copied_res["items"]: for to_remove in item_fields_to_remove: @@ -808,7 +767,7 @@ def translate_add_doc_response(response: Optional[dict], time_diff: float) -> di new_items.insert(loc, error_info) result_dict["processingTimeMs"] = time_diff * 1000 - result_dict["index_name"] = index_name + result_dict["index_name"] = add_docs_params.index_name result_dict["items"] = new_items return result_dict @@ -1069,7 +1028,8 @@ def search(config: Config, index_name: str, text: Union[str, dict], device=None, boost: Optional[Dict] = None, image_download_headers: Optional[Dict] = None, context: Optional[Dict] = None, - score_modifiers: Optional[Dict] = None) -> Dict: + score_modifiers: Optional[Dict] = None, + model_auth: Optional[ModelAuth] = None) -> Dict: """The root search method. Calls the specific search method Validation should go here. Validations include: @@ -1094,6 +1054,7 @@ def search(config: Config, index_name: str, text: Union[str, dict], image_download_headers: headers for downloading images context: a dictionary to allow custom vectors in search, for tensor search only score_modifiers: a dictionary to modify the score based on field values, for tensor search only + model_auth: Authorisation details for downloading a model (if required) Returns: """ @@ -1146,7 +1107,8 @@ def search(config: Config, index_name: str, text: Union[str, dict], return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose, number_of_highlights=num_highlights, simplified_format=simplified_format, filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve, boost=boost, - image_download_headers=image_download_headers, context=context, score_modifiers=score_modifiers + image_download_headers=image_download_headers, context=context, score_modifiers=score_modifiers, + model_auth=model_auth ) elif search_method.upper() == SearchMethod.LEXICAL: search_result = _lexical_search( @@ -1494,7 +1456,8 @@ def assign_query_to_vector_job( device=device, normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'], image_download_headers=q.image_download_headers, - content_type=content_type + content_type=content_type, + model_auth=q.modelAuth ) # If exists, add content to vector job. 
Otherwise create new if jobs.get(vector_job.groupby_key()) is not None: @@ -1545,12 +1508,13 @@ def vectorise_jobs(jobs: List[VectorisedJobs]) -> Dict[JHash, Dict[str, List[flo model_name=v.model_name, model_properties=v.model_properties, content=v.content, device=v.device, normalize_embeddings=v.normalize_embeddings, - image_download_headers=v.image_download_headers + image_download_headers=v.image_download_headers, + model_auth=v.model_auth ) result[v.groupby_key()] = dict(zip(v.content, vectors)) - except s2_inference_errors.S2InferenceError: + except s2_inference_errors.S2InferenceError as e: # TODO: differentiate image processing errors from other types of vectorise errors - raise errors.InvalidArgError(message=f'Could not process given image in: {v.content}') + raise errors.InvalidArgError(message=f'Error vectorising content: {v.content}. Message: {e}') return result @@ -1749,7 +1713,9 @@ def _vector_text_search( attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None, image_download_headers: Optional[Dict] = None, context: Optional[Dict] = None, - score_modifiers: Optional[Dict] = None): + score_modifiers: Optional[Dict] = None, + model_auth: Optional[ModelAuth] = None + ): """ Args: config: @@ -1769,6 +1735,7 @@ def _vector_text_search( image_download_headers: headers for downloading images context: a dictionary to allow custom vectors in search score_modifiers: a dictionary to modify the score based on field values, for tensor search only + model_auth: Authorisation details for downloading a model (if required) Returns: Note: @@ -1829,7 +1796,8 @@ def _vector_text_search( model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=batch, device=selected_device, normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'], - image_download_headers=image_download_headers + image_download_headers=image_download_headers, + model_auth=model_auth ))) for batch in to_be_vectorised ] @@ -2342,9 +2310,11 @@ def get_cuda_info() -> dict: )) - -def vectorise_multimodal_combination_field(field: str, multimodal_object: Dict[str, dict], doc: dict, doc_index: int, - doc_id:str, selected_device:str, index_info, image_repo, field_map:dict): +def vectorise_multimodal_combination_field( + field: str, multimodal_object: Dict[str, dict], doc: dict, doc_index: int, + doc_id:str, selected_device:str, index_info, image_repo, field_map:dict, + model_auth: Optional[ModelAuth] = None +): ''' This function is used to vectorise multimodal combination field. The field content should have the following structure: @@ -2361,16 +2331,18 @@ def vectorise_multimodal_combination_field(field: str, multimodal_object: Dict[s doc_index: the index of the document. 
This is the iterator variable `i` in the main body, used to iterate through the docs
         doc_id: the document id
         selected_device: device from main body
-        index_info: index_info from main body
+        index_info: index_info from main body
+        model_auth: Model download authorisation information (if required)
     Returns:
-        combo_chunk: the combo_chunk to be appended to the main body
-        combo_document_is_valid:  if the document is a valid
-        unsuccessful_docs: appended unsucessful_docs
-        combo_total_vectorise_time: the vectorise time spent in combo field
-        new_fields_from_multimodal_combination: the new fields from multimodal combination field that will be added to index properties
+        combo_chunk: the combo_chunk to be appended to the main body
+        combo_document_is_valid: whether the document is valid
+        unsuccessful_docs: the updated list of unsuccessful docs
+        combo_total_vectorise_time: the vectorise time spent on the combo field
+        new_fields_from_multimodal_combination: the new fields from the multimodal combination field that will be added to
+         index properties
     '''
-    # field_conent = {"tensor_field_one" : {"weight":0.5, "parameter": "test-paramater-1"},
+    # field_content = {"tensor_field_one" : {"weight":0.5, "parameter": "test-parameter-1"},
     #  "tensor_field_two" : {"weight": 0.5, parameter": "test-parameter-2"}},
     combo_document_is_valid = True
     combo_vectorise_time_to_add = 0
@@ -2406,7 +2378,7 @@ def vectorise_multimodal_combination_field(field: str, multimodal_object: Dict[s
     else:
         try:
             if isinstance(sub_content, str) and index_info.index_settings[NsField.index_defaults][
-                NsField.treat_urls_and_pointers_as_images]:
+                    NsField.treat_urls_and_pointers_as_images]:
                 if not isinstance(image_repo[sub_content], Exception):
                     image_data = image_repo[sub_content]
                 else:
@@ -2438,14 +2410,14 @@ def vectorise_multimodal_combination_field(field: str, multimodal_object: Dict[s
                 model_name=index_info.model_name, model_properties=_get_model_properties(index_info),
                 content=text_content_to_vectorise, device=selected_device, normalize_embeddings=normalize_embeddings,
-                infer=infer_if_image)
+                infer=infer_if_image, model_auth=model_auth)
             image_vectors = []
             if len(image_content_to_vectorise) > 0:
                 image_vectors = s2_inference.vectorise(
                 model_name=index_info.model_name, model_properties=_get_model_properties(index_info),
                 content=image_content_to_vectorise, device=selected_device, normalize_embeddings=normalize_embeddings,
-                infer=infer_if_image)
+                infer=infer_if_image, model_auth=model_auth)
             end_time = timer()
             combo_vectorise_time_to_add += (end_time - start_time)
         except (s2_inference_errors.UnknownModelError,
diff --git a/src/marqo/tensor_search/web/api_utils.py b/src/marqo/tensor_search/web/api_utils.py
index fb1d9c357..7bf060a4e 100644
--- a/src/marqo/tensor_search/web/api_utils.py
+++ b/src/marqo/tensor_search/web/api_utils.py
@@ -4,7 +4,7 @@
 from marqo.tensor_search import enums
 from typing import Optional
 from marqo.tensor_search.utils import construct_authorized_url
-from marqo import config
+from marqo.tensor_search.models.add_docs_objects import ModelAuth


 def upconstruct_authorized_url(opensearch_url: str) -> str:
@@ -66,7 +66,7 @@ def translate_api_device(device: Optional[str]) -> Optional[str]:


 def decode_image_download_headers(image_download_headers: Optional[str] = None) -> dict:
-    """Decodes and image download header string into a Python dict
+    """Decodes an image download header string into a Python dict

     Args:
         image_download_headers: JSON-serialised, URL encoded header dictionary
@@ -75,7 +75,7 @@ def
decode_image_download_headers(image_download_headers: Optional[str] = None) image_download_headers as a dict Raises: - InvalidArgError is there is trouble parsing the dictionary + InvalidArgError if there is trouble parsing the dictionary """ if not image_download_headers: return dict() @@ -88,6 +88,26 @@ def decode_image_download_headers(image_download_headers: Optional[str] = None) raise InvalidArgError(f"Error parsing image_download_headers. Message: {e}") +def decode_query_string_model_auth(model_auth: Optional[str] = None) -> Optional[ModelAuth]: + """Decodes a url encoded ModelAuth string into a ModelAuth object + + Args: + model_auth: JSON-serialised, URL encoded ModelAuth dictionary + + Returns: + model_auth as a ModelAuth object, if found. Otherwise None + + Raises: + ValidationError if there is trouble parsing the string + """ + if not model_auth: + return None + else: + as_str = urllib.parse.unquote_plus(model_auth) + as_objc = ModelAuth.parse_raw(as_str) + return as_objc + + def decode_mappings(mappings: Optional[str] = None) -> dict: """Decodes mappings string into a Python dict diff --git a/src/marqo/version.py b/src/marqo/version.py index 1afbb29e7..293a218ba 100644 --- a/src/marqo/version.py +++ b/src/marqo/version.py @@ -1,4 +1,4 @@ -__version__ = "0.0.18" +__version__ = "0.0.19" def get_version() -> str: diff --git a/tests/s2_inference/model_downloading/__init__.py b/tests/s2_inference/model_downloading/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/s2_inference/model_downloading/test_from_hf.py b/tests/s2_inference/model_downloading/test_from_hf.py new file mode 100644 index 000000000..0e9c500ba --- /dev/null +++ b/tests/s2_inference/model_downloading/test_from_hf.py @@ -0,0 +1,63 @@ +import unittest +from unittest.mock import MagicMock, patch +from marqo.s2_inference.errors import ModelDownloadError +from marqo.tensor_search.models.external_apis.hf import HfAuth, HfModelLocation +from marqo.s2_inference.model_downloading.from_hf import download_model_from_hf +from huggingface_hub.utils._errors import RepositoryNotFoundError + + +class TestDownloadModelFromHF(unittest.TestCase): + def setUp(self): + self.hf_location = HfModelLocation(repo_id="test-repo-id", filename="test-filename") + self.hf_auth = HfAuth(token="test-token") + + def test_download_model_from_hf_success(self): + with patch("marqo.s2_inference.model_downloading.from_hf.hf_hub_download", + return_value="model_path") as hf_hub_download_mock: + result = download_model_from_hf(self.hf_location, self.hf_auth) + self.assertEqual(result, "model_path") + hf_hub_download_mock.assert_called_once_with(repo_id="test-repo-id", filename="test-filename", token="test-token") + + def test_download_model_from_hf_no_auth(self): + with patch( + "marqo.s2_inference.model_downloading.from_hf.hf_hub_download", + return_value="model_path") as hf_hub_download_mock: + result = download_model_from_hf(self.hf_location) + self.assertEqual(result, "model_path") + hf_hub_download_mock.assert_called_once_with(repo_id="test-repo-id", filename="test-filename") + + def test_download_model_from_hf_repository_not_found_error(self): + with patch("marqo.s2_inference.model_downloading.from_hf.hf_hub_download", + side_effect=RepositoryNotFoundError("repo not found")): + with self.assertRaises(ModelDownloadError): + download_model_from_hf(self.hf_location, self.hf_auth) + + def test_download_model_from_hf_invalid_location(self): + invalid_location = HfModelLocation(repo_id="", filename="test-filename") + with 
patch("marqo.s2_inference.model_downloading.from_hf.hf_hub_download", + side_effect=RepositoryNotFoundError("repo not found")): + with self.assertRaises(ModelDownloadError): + download_model_from_hf(invalid_location, self.hf_auth) + + def test_download_model_from_hf_invalid_auth(self): + invalid_auth = HfAuth(token="") + with patch("marqo.s2_inference.model_downloading.from_hf.hf_hub_download", + side_effect=RepositoryNotFoundError("repo not found")): + with self.assertRaises(ModelDownloadError): + download_model_from_hf(self.hf_location, invalid_auth) + + def test_download_model_from_hf_unexpected_error(self): + with patch("marqo.s2_inference.model_downloading.from_hf.hf_hub_download", + side_effect=Exception("Unexpected error")): + with self.assertRaises(Exception): + download_model_from_hf(self.hf_location, self.hf_auth) + + def test_download_model_from_hf_with_download_dir(self): + with patch("marqo.s2_inference.model_downloading.from_hf.hf_hub_download", + return_value="model_path") as hf_hub_download_mock: + with patch("marqo.s2_inference.model_downloading.from_hf.logger.warning") as logger_warning_mock: + result = download_model_from_hf(self.hf_location, self.hf_auth, download_dir="custom_download_dir") + self.assertEqual(result, "model_path") + hf_hub_download_mock.assert_called_once_with(repo_id="test-repo-id", filename="test-filename", token="test-token") + logger_warning_mock.assert_called_once() + diff --git a/tests/s2_inference/model_downloading/test_from_s3.py b/tests/s2_inference/model_downloading/test_from_s3.py new file mode 100644 index 000000000..161bcb93a --- /dev/null +++ b/tests/s2_inference/model_downloading/test_from_s3.py @@ -0,0 +1,73 @@ +from marqo.s2_inference.model_downloading.from_s3 import ( + get_presigned_s3_url, + get_s3_model_absolute_cache_path, + check_s3_model_already_exists, + get_s3_model_cache_filename, +) +from botocore.exceptions import NoCredentialsError +from marqo.s2_inference.configs import ModelCache +import unittest +import botocore +from unittest.mock import patch +from marqo.s2_inference.errors import ModelDownloadError +from marqo.tensor_search.models.external_apis.s3 import S3Auth, S3Location + + +class TestModelAuthEdgeCases(unittest.TestCase): + def setUp(self): + self.s3_location = S3Location(Bucket="test-bucket", Key="test-key") + self.s3_auth = S3Auth(aws_access_key_id="test-access-key", aws_secret_access_key="test-secret-key") + + def test_get_presigned_s3_url_no_credentials_error(self): + with patch("boto3.client") as boto3_client_mock: + boto3_client_mock.return_value.generate_presigned_url.side_effect = NoCredentialsError + with self.assertRaises(ModelDownloadError): + get_presigned_s3_url(self.s3_location, self.s3_auth) + + def test_get_presigned_s3_url_invalid_location(self): + invalid_location = S3Location(Bucket="", Key="") + with self.assertRaises(botocore.exceptions.ParamValidationError): + get_presigned_s3_url(invalid_location, self.s3_auth) + + def test_get_s3_model_absolute_cache_path_empty_key(self): + empty_key_location = S3Location(Bucket="test-bucket", Key="") + with patch("os.path.expanduser", return_value="some_cache_path"): + result = get_s3_model_absolute_cache_path(empty_key_location) + self.assertEqual(result, "some_cache_path/") + + def test_check_s3_model_already_exists_empty_key(self): + empty_key_location = S3Location(Bucket="test-bucket", Key="") + with patch("os.path.isfile", return_value=True): + result = check_s3_model_already_exists(empty_key_location) + self.assertTrue(result) + + def 
test_check_s3_model_already_exists_no_file(self): + with patch("os.path.isfile", return_value=False): + result = check_s3_model_already_exists(self.s3_location) + self.assertFalse(result) + + def test_get_s3_model_cache_filename_empty_key(self): + empty_key_location = S3Location(Bucket="test-bucket", Key="") + result = get_s3_model_cache_filename(empty_key_location) + self.assertEqual(result, "") + + def test_get_s3_model_absolute_cache_path_invalid_cache_dir(self): + with patch("os.path.expanduser", return_value=""): + result = get_s3_model_absolute_cache_path(self.s3_location) + self.assertEqual(result, "test-key") + + def test_get_s3_model_absolute_cache_path_cache_dir_not_expanded(self): + with patch("os.path.expanduser", side_effect=lambda x: x): + with patch("os.path.join", side_effect=lambda x, y: f"{x}/{y}"): + result = get_s3_model_absolute_cache_path(self.s3_location) + self.assertEqual(result, f"{ModelCache.clip_cache_path}/test-key") + + def test_check_s3_model_already_exists_os_error(self): + with patch("os.path.isfile", side_effect=OSError("Test OSError")): + with self.assertRaises(OSError): + check_s3_model_already_exists(self.s3_location) + + def test_get_s3_model_cache_filename_with_directory(self): + location_with_directory = S3Location(Bucket="test-bucket", Key="models/test-key") + result = get_s3_model_cache_filename(location_with_directory) + self.assertEqual(result, "test-key") diff --git a/tests/s2_inference/test_clip_utils.py b/tests/s2_inference/test_clip_utils.py index a7c664344..239779e4b 100644 --- a/tests/s2_inference/test_clip_utils.py +++ b/tests/s2_inference/test_clip_utils.py @@ -1,13 +1,19 @@ import copy import itertools - import PIL import requests.exceptions - from marqo.s2_inference import clip_utils, types import unittest from unittest import mock import requests +from marqo.s2_inference.clip_utils import CLIP, download_model, OPEN_CLIP +from marqo.tensor_search.enums import ModelProperties +from marqo.tensor_search.models.private_models import ModelLocation, ModelAuth +from unittest.mock import patch +import pytest +from marqo.tensor_search.models.private_models import ModelLocation, ModelAuth +from marqo.tensor_search.models.private_models import S3Auth, S3Location, HfModelLocation +from marqo.s2_inference.configs import ModelCache class TestEncoding(unittest.TestCase): @@ -71,3 +77,141 @@ def run(): return True run() + +class TestDownloadFromRepo(unittest.TestCase): + + @patch('marqo.s2_inference.clip_utils.download_model') + def test__download_from_repo_with_auth(self, mock_download_model, ): + mock_download_model.return_value = 'model.pth' + location = ModelLocation( + s3=S3Location(Bucket='some_bucket', Key='some_key'), auth_required=True) + s3_auth = S3Auth(aws_access_key_id='some_key_id', aws_secret_access_key='some_secret') + + model_props = { + ModelProperties.model_location: location.dict(), + } + auth = { + 's3': s3_auth.dict() + } + + clip = CLIP(model_properties=model_props, model_auth=auth) + assert clip._download_from_repo() == 'model.pth' + mock_download_model.assert_called_once_with(repo_location=location, auth=auth) + + @patch('marqo.s2_inference.clip_utils.download_model') + def test__download_from_repo_without_auth(self, mock_download_model, ): + mock_download_model.return_value = 'model.pth' + location = ModelLocation( + s3=S3Location(Bucket='some_bucket', Key='some_key'), auth_required=False) + + model_props = { + ModelProperties.model_location: location.dict(), + } + + clip = CLIP(model_properties=model_props) + assert 
clip._download_from_repo() == 'model.pth' + mock_download_model.assert_called_once_with(repo_location=location) + + @patch('marqo.s2_inference.clip_utils.download_model') + def test__download_from_repo_with_empty_filepath(self, mock_download_model): + mock_download_model.return_value = None + location = ModelLocation( + s3=S3Location(Bucket='some_bucket', Key='some_key'), auth_required=False) + + model_props = { + ModelProperties.model_location: location.dict(), + } + + clip = CLIP(model_properties=model_props) + + with pytest.raises(RuntimeError): + clip._download_from_repo() + + mock_download_model.assert_called_once_with(repo_location=location) + +class TestLoad(unittest.TestCase): + """tests the CLIP.load() method""" + @patch('marqo.s2_inference.clip_utils.clip.load', return_value=(mock.Mock(), mock.Mock())) + def test_load_without_model_properties(self, mock_clip_load): + clip = CLIP() + clip.load() + mock_clip_load.assert_called_once_with('ViT-B/32', device='cpu', jit=False, download_root=ModelCache.clip_cache_path) + + @patch('marqo.s2_inference.clip_utils.clip.load', return_value=(mock.Mock(), mock.Mock())) + @patch('os.path.isfile', return_value=True) + def test_load_with_local_file(self, mock_isfile, mock_clip_load): + model_path = 'localfile.pth' + clip = CLIP(model_properties={'localpath': model_path}) + clip.load() + mock_clip_load.assert_called_once_with(name=model_path, device='cpu', jit=False, download_root=ModelCache.clip_cache_path) + + @patch('marqo.s2_inference.clip_utils.download_model', return_value='downloaded_model.pth') + @patch('marqo.s2_inference.clip_utils.clip.load', return_value=(mock.Mock(), mock.Mock())) + @patch('os.path.isfile', return_value=False) + @patch('validators.url', return_value=True) + def test_load_with_url(self, mock_url_valid, mock_isfile, mock_clip_load, mock_download_model): + model_url = 'http://example.com/model.pth' + clip = CLIP(model_properties={'url': model_url}) + clip.load() + mock_download_model.assert_called_once_with(url=model_url) + mock_clip_load.assert_called_once_with(name='downloaded_model.pth', device='cpu', jit=False, download_root=ModelCache.clip_cache_path) + + @patch('marqo.s2_inference.clip_utils.CLIP._download_from_repo', return_value='downloaded_model.pth') + @patch('marqo.s2_inference.clip_utils.clip.load', return_value=(mock.Mock(), mock.Mock())) + def test_load_with_model_location(self, mock_clip_load, mock_download_from_repo): + model_location = ModelLocation(s3=S3Location(Bucket='some_bucket', Key='some_key')) + clip = CLIP(model_properties={ModelProperties.model_location: model_location.dict()}) + clip.load() + mock_download_from_repo.assert_called_once() + mock_clip_load.assert_called_once_with(name='downloaded_model.pth', device='cpu', jit=False, download_root=ModelCache.clip_cache_path) + +class TestOpenClipLoad(unittest.TestCase): + + @patch('marqo.s2_inference.clip_utils.open_clip.create_model_and_transforms', + return_value=(mock.Mock(), mock.Mock(), mock.Mock())) + def test_load_without_model_properties(self, mock_open_clip_create_model_and_transforms): + """By default laion400m_e32 is loaded...""" + open_clip = OPEN_CLIP() + open_clip.load() + mock_open_clip_create_model_and_transforms.assert_called_once_with( + 'ViT-B-32-quickgelu', pretrained='laion400m_e32', + device='cpu', jit=False, cache_dir=ModelCache.clip_cache_path) + + @patch('marqo.s2_inference.clip_utils.open_clip.create_model_and_transforms', + return_value=(mock.Mock(), mock.Mock(), mock.Mock())) + @patch('os.path.isfile', 
return_value=True) + def test_load_with_local_file(self, mock_isfile, mock_open_clip_create_model_and_transforms): + model_path = 'localfile.pth' + open_clip = OPEN_CLIP(model_properties={'localpath': model_path}) + open_clip.load() + mock_open_clip_create_model_and_transforms.assert_called_once_with( + model_name=open_clip.model_name, jit=False, pretrained=model_path, + precision='fp32', image_mean=None, image_std=None, + device='cpu', cache_dir=ModelCache.clip_cache_path) + + @patch('marqo.s2_inference.clip_utils.open_clip.create_model_and_transforms', + return_value=(mock.Mock(), mock.Mock(), mock.Mock())) + @patch('validators.url', return_value=True) + @patch('marqo.s2_inference.clip_utils.download_model', return_value='model.pth') + def test_load_with_url(self, mock_download_model, mock_validators_url, mock_open_clip_create_model_and_transforms): + model_url = 'http://model.com/model.pth' + open_clip = OPEN_CLIP(model_properties={'url': model_url}) + open_clip.load() + mock_download_model.assert_called_once_with(url=model_url) + mock_open_clip_create_model_and_transforms.assert_called_once_with( + model_name=open_clip.model_name, jit=False, pretrained='model.pth', precision='fp32', + image_mean=None, image_std=None, device='cpu', cache_dir=ModelCache.clip_cache_path) + + @patch('marqo.s2_inference.clip_utils.open_clip.create_model_and_transforms', + return_value=(mock.Mock(), mock.Mock(), mock.Mock())) + @patch('marqo.s2_inference.clip_utils.CLIP._download_from_repo', + return_value='model.pth') + def test_load_with_model_location(self, mock_download_from_repo, mock_open_clip_create_model_and_transforms): + open_clip = OPEN_CLIP(model_properties={ + ModelProperties.model_location: ModelLocation( + auth_required=True, hf=HfModelLocation(repo_id='someId', filename='some_file.pt')).dict()}) + open_clip.load() + mock_download_from_repo.assert_called_once() + mock_open_clip_create_model_and_transforms.assert_called_once_with( + model_name=open_clip.model_name, jit=False, pretrained='model.pth', precision='fp32', + image_mean=None, image_std=None, device='cpu', cache_dir=ModelCache.clip_cache_path) diff --git a/tests/s2_inference/test_custom_clip_utils.py b/tests/s2_inference/test_custom_clip_utils.py new file mode 100644 index 000000000..4285beb6a --- /dev/null +++ b/tests/s2_inference/test_custom_clip_utils.py @@ -0,0 +1,200 @@ +import unittest +import urllib +from unittest.mock import patch, MagicMock +from marqo.s2_inference.processing.custom_clip_utils import ( + download_pretrained_from_s3, download_model, download_pretrained_from_url, + ModelDownloadError, S3Auth, S3Location, ModelAuth, ModelLocation +) +from marqo.s2_inference.errors import InvalidModelPropertiesError +import tempfile +import os + +class TestDownloadModel(unittest.TestCase): + def test_both_location_and_url_provided(self): + with self.assertRaises(InvalidModelPropertiesError): + download_model(repo_location=ModelLocation(s3=S3Location(Bucket="test_bucket", Key="test_key")), url="http://example.com/model.pt") + + def test_neither_location_nor_url_provided(self): + with self.assertRaises(InvalidModelPropertiesError): + download_model() + + @patch("marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_s3") + def test_download_from_s3(self, mock_download_s3): + mock_download_s3.return_value = "/path/to/model.pt" + repo_location = ModelLocation(s3=S3Location(Bucket="test_bucket", Key="test_key")) + auth = ModelAuth(s3=S3Auth(aws_access_key_id="test_access_key", aws_secret_access_key="test_secret_key")) 
+ model_path = download_model(repo_location=repo_location, auth=auth) + + self.assertEqual(model_path, "/path/to/model.pt") + mock_download_s3.assert_called_once_with(location=repo_location.s3, auth=auth.s3, download_dir=None) + + @patch("marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url") + def test_download_from_url(self, mock_download_url): + mock_download_url.return_value = "/path/to/model.pt" + url = "http://example.com/model.pt" + model_path = download_model(url=url) + + self.assertEqual(model_path, "/path/to/model.pt") + mock_download_url.assert_called_once_with(url=url, cache_dir=None) + + +class TestDownloadPretrainedFromS3(unittest.TestCase): + def setUp(self): + self.s3_location = S3Location(Bucket="test_bucket", Key="remote_path/test_key.pt") + self.s3_auth = S3Auth(aws_access_key_id="test_access_key", aws_secret_access_key="test_secret_key") + + @patch("marqo.s2_inference.processing.custom_clip_utils.check_s3_model_already_exists") + def test_model_exists_locally(self, mock_check_s3_model): + mock_check_s3_model.return_value = True + + with patch("marqo.s2_inference.processing.custom_clip_utils.get_s3_model_absolute_cache_path" + ) as mock_get_abs_path: + with patch("marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url" + ) as mock_download_pretrained_from_url: + mock_get_abs_path.return_value = "/path/to/model.pt" + result = download_pretrained_from_s3(location=self.s3_location, auth=self.s3_auth) + + self.assertEqual(result, "/path/to/model.pt") + mock_download_pretrained_from_url.assert_not_called() + mock_check_s3_model.assert_called_once_with(location=self.s3_location) + + @patch("marqo.s2_inference.processing.custom_clip_utils.check_s3_model_already_exists") + @patch("marqo.s2_inference.processing.custom_clip_utils.get_presigned_s3_url") + def test_model_does_not_exist_locally(self, mock_get_presigned_url, mock_check_s3_model): + mock_check_s3_model.return_value = False + mock_get_presigned_url.return_value = "http://example.com/model.pt" + + with patch("marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url" + ) as mock_download_pretrained_from_url: + mock_download_pretrained_from_url.return_value = "/path/to/model.pt" + result = download_pretrained_from_s3(location=self.s3_location, auth=self.s3_auth) + + self.assertEqual(result, "/path/to/model.pt") + mock_download_pretrained_from_url.assert_called() + mock_get_presigned_url.assert_called_once_with(location=self.s3_location, auth=self.s3_auth) + + # note the cache file name is to come from the key rather than the URL + mock_download_pretrained_from_url.assert_called_once_with( + url="http://example.com/model.pt", + cache_dir=None, + # Base name of s3 key: + cache_file_name='test_key.pt' + ) + + @patch("marqo.s2_inference.processing.custom_clip_utils.check_s3_model_already_exists") + @patch("marqo.s2_inference.processing.custom_clip_utils.get_presigned_s3_url") + def test_model_download_raises_403_error(self, mock_get_presigned_url, mock_check_s3_model): + mock_check_s3_model.return_value = False + mock_get_presigned_url.return_value = "http://example.com/model.pt" + + with patch("marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url") as mock_download_url: + mock_download_url.side_effect = urllib.error.HTTPError(url=None, code=403, msg=None, hdrs=None, fp=None) + + with self.assertRaises(ModelDownloadError): + download_pretrained_from_s3(location=self.s3_location, auth=self.s3_auth) + +class 
TestDownloadPretrainedFromURL(unittest.TestCase): + def setUp(self): + self.url = "http://example.com/model.pt" + + @patch("urllib.request.urlopen") + @patch("os.path.isfile") + def test_file_exists_locally(self, mock_isfile, mock_urlopen): + mock_isfile.return_value = True + with patch("builtins.open", unittest.mock.mock_open()) as mock_open: + with patch("marqo.s2_inference.processing.custom_clip_utils.tqdm") as mock_tqdm: + with patch("marqo.s2_inference.processing.custom_clip_utils.ModelCache") as mock_cache: + with tempfile.TemporaryDirectory() as temp_cache_dir: + mock_cache.clip_cache_path = temp_cache_dir + result = download_pretrained_from_url(self.url) + + self.assertEqual(result, os.path.join(temp_cache_dir, 'model.pt')) + mock_urlopen.assert_not_called() + mock_isfile.assert_called_once() + + @patch("os.path.isfile") + @patch("urllib.request.urlopen") + def test_file_does_not_exist_locally(self, mock_urlopen, mock_isfile): + mock_isfile.return_value = False + mock_source = MagicMock() + mock_source.headers.get.return_value = 0 + mock_source.read.return_value = b'' + mock_urlopen.return_value.__enter__.return_value = mock_source + + with patch("builtins.open", unittest.mock.mock_open()) as mock_open: + with patch("marqo.s2_inference.processing.custom_clip_utils.tqdm") as mock_tqdm: + with patch("marqo.s2_inference.processing.custom_clip_utils.ModelCache") as mock_cache: + with tempfile.TemporaryDirectory() as temp_cache_dir: + mock_cache.clip_cache_path = temp_cache_dir + result = download_pretrained_from_url(self.url) + + self.assertEqual(result, os.path.join(temp_cache_dir, 'model.pt')) + mock_isfile.assert_called_once() + mock_urlopen.assert_called_once_with(self.url) + mock_open.assert_called_once_with(os.path.join(temp_cache_dir, 'model.pt'), "wb") + + @patch("os.path.isfile") + @patch("urllib.request.urlopen") + def test_file_does_not_exist_locally_custom_filename(self, mock_urlopen, mock_isfile): + mock_isfile.return_value = False + mock_source = MagicMock() + mock_source.headers.get.return_value = 0 + mock_source.read.return_value = b'' + mock_urlopen.return_value.__enter__.return_value = mock_source + + with patch("builtins.open", unittest.mock.mock_open()) as mock_open: + with patch("marqo.s2_inference.processing.custom_clip_utils.tqdm") as mock_tqdm: + with patch("marqo.s2_inference.processing.custom_clip_utils.ModelCache") as mock_cache: + with tempfile.TemporaryDirectory() as temp_cache_dir: + mock_cache.clip_cache_path = temp_cache_dir + result = download_pretrained_from_url(self.url, cache_file_name='unusual_model.pt') + + self.assertEqual(result, os.path.join(temp_cache_dir, 'unusual_model.pt')) + mock_isfile.assert_called_once() + mock_urlopen.assert_called_once_with(self.url) + mock_open.assert_called_once_with(os.path.join(temp_cache_dir, 'unusual_model.pt'), "wb") + + @patch("os.path.isfile") + @patch("urllib.request.urlopen") + def test_file_does_not_exist_locally_custom_cache_dir(self, mock_urlopen, mock_isfile): + mock_isfile.return_value = False + mock_source = MagicMock() + mock_source.headers.get.return_value = 0 + mock_source.read.return_value = b'' + mock_urlopen.return_value.__enter__.return_value = mock_source + + with patch("builtins.open", unittest.mock.mock_open()) as mock_open: + with patch("marqo.s2_inference.processing.custom_clip_utils.tqdm") as mock_tqdm: + with patch("marqo.s2_inference.processing.custom_clip_utils.ModelCache") as mock_cache: + with tempfile.TemporaryDirectory() as temp_cache_dir: + custom_dir = os.path.join(temp_cache_dir, 
'special/cache') + mock_cache.clip_cache_path = temp_cache_dir + result = download_pretrained_from_url(self.url, cache_dir=custom_dir) + + self.assertEqual(result, os.path.join(custom_dir, 'model.pt')) + mock_isfile.assert_called_once() + mock_urlopen.assert_called_once_with(self.url) + mock_open.assert_called_once_with(os.path.join(custom_dir, 'model.pt'), "wb") + + @patch("os.path.isfile") + @patch("urllib.request.urlopen") + def test_file_does_not_exist_locally_custom_cache_path(self, mock_urlopen, mock_isfile): + mock_isfile.return_value = False + mock_source = MagicMock() + mock_source.headers.get.return_value = 0 + mock_source.read.return_value = b'' + mock_urlopen.return_value.__enter__.return_value = mock_source + + with patch("builtins.open", unittest.mock.mock_open()) as mock_open: + with patch("marqo.s2_inference.processing.custom_clip_utils.tqdm") as mock_tqdm: + with patch("marqo.s2_inference.processing.custom_clip_utils.ModelCache") as mock_cache: + with tempfile.TemporaryDirectory() as temp_cache_dir: + custom_dir = os.path.join(temp_cache_dir, 'special/cache') + mock_cache.clip_cache_path = temp_cache_dir + result = download_pretrained_from_url( + self.url, cache_dir=custom_dir, cache_file_name='unusual_model.pt') + + self.assertEqual(result, os.path.join(custom_dir, 'unusual_model.pt')) + mock_isfile.assert_called_once() + mock_urlopen.assert_called_once_with(self.url) + mock_open.assert_called_once_with(os.path.join(custom_dir, 'unusual_model.pt'), "wb") \ No newline at end of file diff --git a/tests/s2_inference/test_generic_clip_model.py b/tests/s2_inference/test_generic_clip_model.py index bb9a55f3d..f979a1612 100644 --- a/tests/s2_inference/test_generic_clip_model.py +++ b/tests/s2_inference/test_generic_clip_model.py @@ -1,5 +1,5 @@ import numpy as np - +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.errors import IndexNotFoundError from marqo.s2_inference.errors import UnknownModelError, ModelLoadError from marqo.tensor_search import tensor_search @@ -65,7 +65,9 @@ def test_create_index_and_add_documents_with_generic_open_clip_model_properties_ }] auto_refresh = True - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=auto_refresh) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs, auto_refresh=auto_refresh) + ) # test if we can get the document by _id assert tensor_search.get_document_by_id( @@ -84,8 +86,8 @@ def test_create_index_and_add_documents_with_generic_open_clip_model_properties_ "desc 2": "test again test again test again" }] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs2, - auto_refresh=auto_refresh) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs2, auto_refresh=auto_refresh)) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -130,7 +132,9 @@ def test_pipeline_with_generic_openai_clip_model_properties_url(self): }] auto_refresh = True - tensor_search.add_documents(config=self.config, index_name=self.index_name_2, docs=docs, auto_refresh=auto_refresh) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_2, docs=docs, auto_refresh=auto_refresh + )) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_2, @@ -147,8 +151,8 @@ def 
test_pipeline_with_generic_openai_clip_model_properties_url(self): "desc 2": "test again test again test again" }] - tensor_search.add_documents(config=self.config, index_name=self.index_name_2, docs=docs2, - auto_refresh=auto_refresh) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_2, docs=docs2, auto_refresh=auto_refresh)) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_2, @@ -196,8 +200,8 @@ def test_pipeline_with_generic_open_clip_model_properties_localpath(self): }] auto_refresh = True - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs, - auto_refresh=auto_refresh) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs, auto_refresh=auto_refresh)) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -214,8 +218,8 @@ def test_pipeline_with_generic_open_clip_model_properties_localpath(self): "desc 2": "test again test again test again" }] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs2, - auto_refresh=auto_refresh) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs2, auto_refresh=auto_refresh)) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -318,7 +322,8 @@ def test_add_documents_text_and_image(self): }] auto_refresh = True - tensor_search.add_documents(config=config, index_name=index_name, docs=docs, auto_refresh=auto_refresh) + tensor_search.add_documents(config=config, add_docs_params=AddDocsParams( + index_name=index_name, docs=docs, auto_refresh=auto_refresh)) def test_load_generic_clip_without_url_or_localpath(self): diff --git a/tests/s2_inference/test_generic_model.py b/tests/s2_inference/test_generic_model.py index 1076c4b92..4ae2ea24e 100644 --- a/tests/s2_inference/test_generic_model.py +++ b/tests/s2_inference/test_generic_model.py @@ -1,5 +1,5 @@ import numpy as np - +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.errors import IndexNotFoundError from marqo.s2_inference.errors import InvalidModelPropertiesError, UnknownModelError, ModelLoadError from marqo.tensor_search import tensor_search @@ -92,7 +92,8 @@ def test_add_documents(self): }] auto_refresh = True - tensor_search.add_documents(config=config, index_name=index_name, docs=docs, auto_refresh=auto_refresh) + tensor_search.add_documents(config=config, add_docs_params=AddDocsParams( + index_name=index_name, docs=docs, auto_refresh=auto_refresh)) def test_validate_model_properties_missing_required_keys(self): """_validate_model_properties should throw an exception if required keys are not given. 
@@ -141,7 +142,8 @@ def test_validate_model_properties_missing_properties(self): "type": "test", "notes": ""} - validated_model_properties = _validate_model_properties(model_name=model_name, model_properties=None) + validated_model_properties = _validate_model_properties( + model_name=model_name, model_properties=None) self.assertEqual(registry_test_model_properties, validated_model_properties) diff --git a/tests/tensor_search/models/test_private_models.py b/tests/tensor_search/models/test_private_models.py new file mode 100644 index 000000000..6082f14bb --- /dev/null +++ b/tests/tensor_search/models/test_private_models.py @@ -0,0 +1,51 @@ +import unittest +from marqo.tensor_search.models.private_models import ModelAuth, ModelLocation +from marqo.errors import InvalidArgError +from marqo.tensor_search.models.external_apis.hf import HfAuth, HfModelLocation +from marqo.tensor_search.models.external_apis.s3 import S3Auth, S3Location + +class TestModelAuth(unittest.TestCase): + def test_no_auth(self): + with self.assertRaises(InvalidArgError): + ModelAuth() + + def test_multiple_auth(self): + with self.assertRaises(InvalidArgError): + ModelAuth( + s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test"), + hf=HfAuth(token="test")) + + def test_s3_auth(self): + try: + ModelAuth(s3=S3Auth(aws_secret_access_key="test", aws_access_key_id="test")) + except InvalidArgError: + self.fail("ModelAuth raised InvalidArgError unexpectedly!") + + def test_hf_auth(self): + try: + ModelAuth(hf=HfAuth(token="test")) + except InvalidArgError: + self.fail("ModelAuth raised InvalidArgError unexpectedly!") + +class TestModelLocation(unittest.TestCase): + def test_no_location(self): + with self.assertRaises(InvalidArgError): + ModelLocation() + + def test_multiple_locations(self): + with self.assertRaises(InvalidArgError): + ModelLocation( + s3=S3Location(Bucket="test", Key="test"), + hf=HfModelLocation(repo_id="test", filename="test")) + + def test_s3_location(self): + try: + ModelLocation(s3=S3Location(Bucket="test", Key="test")) + except InvalidArgError: + self.fail("ModelLocation raised InvalidArgError unexpectedly!") + + def test_hf_location(self): + try: + ModelLocation(hf=HfModelLocation(repo_id="test", filename="test")) + except InvalidArgError: + self.fail("ModelLocation raised InvalidArgError unexpectedly!") diff --git a/tests/tensor_search/test__httprequests.py b/tests/tensor_search/test__httprequests.py index 2db4ba913..adf3b8334 100644 --- a/tests/tensor_search/test__httprequests.py +++ b/tests/tensor_search/test__httprequests.py @@ -1,6 +1,6 @@ import requests from tests.marqo_test import MarqoTestCase -from marqo import _httprequests +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from unittest import mock from marqo.tensor_search import tensor_search from marqo.errors import ( @@ -31,8 +31,10 @@ def test_too_many_reqs_error(self): def run(): try: res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some ": "doc"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=[{"some ": "doc"}], auto_refresh=True + ) ) raise AssertionError except TooManyRequestsError: diff --git a/tests/tensor_search/test_add_documents.py b/tests/tensor_search/test_add_documents.py index 9fde50886..743451f7c 100644 --- a/tests/tensor_search/test_add_documents.py +++ b/tests/tensor_search/test_add_documents.py @@ -1,5 +1,5 @@ import copy -import fileinput +from 
marqo.tensor_search.models.add_docs_objects import AddDocsParams import functools import json import math @@ -51,13 +51,18 @@ def _match_all(self, index_name, verbose=False): def test_add_plain_id_field(self): """does a plain 'id' field work (in the doc body)? """ tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "_id": "123", - "id": "abcdefgh", - "title 1": "content 1", - "desc 2": "content 2. blah blah blah" - }], auto_refresh=True) + tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=[{ + "_id": "123", + "id": "abcdefgh", + "title 1": "content 1", + "desc 2": "content 2. blah blah blah" + }], + auto_refresh=True + ) + ) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123") == { @@ -72,38 +77,44 @@ def test_add_documents_dupe_ids(self): Should only use the latest inserted ID. Make sure it doesn't get the first/middle one """ - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "_id": "3", - "title": "doc 3b" - }, - - ], auto_refresh=True) - + tensor_search.add_documents( + config=self.config, + add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{ + "_id": "3", + "title": "doc 3b" + }], + auto_refresh=True + ) + ) doc_3_solo = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="3", show_vectors=True) tensor_search.delete_index(config=self.config, index_name=self.index_name_1) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "_id": "1", - "title": "doc 1" - }, - { - "_id": "2", - "title": "doc 2", - }, - { - "_id": "3", - "title": "doc 3a", - }, - { - "_id": "3", - "title": "doc 3b" - }, - - ], auto_refresh=True) + tensor_search.add_documents( + config=self.config, + add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "_id": "1", + "title": "doc 1" + }, + { + "_id": "2", + "title": "doc 2", + }, + { + "_id": "3", + "title": "doc 3a", + }, + { + "_id": "3", + "title": "doc 3b" + }], + auto_refresh=True + ) + ) doc_3_duped = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -116,12 +127,16 @@ def test_update_docs_update_chunks(self): """Updating a doc needs to update the corresponding chunks" """ tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "_id": "123", - "title 1": "content 1", - "desc 2": "content 2. blah blah blah" - }], auto_refresh=True) + tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "_id": "123", + "title 1": "content 1", + "desc 2": "content 2. blah blah blah" + }], + auto_refresh=True) + ) count0_res = requests.post( F"{self.endpoint}/{self.index_name_1}/_count", timeout=self.config.timeout, @@ -129,12 +144,18 @@ def test_update_docs_update_chunks(self): ) count0 = count0_res.json()["count"] assert count0 == 1 - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "_id": "123", - "title 1": "content 1", - "desc 2": "content 2. 
blah blah blah" - }], auto_refresh=True) + tensor_search.add_documents( + config=self.config, + add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=[{ + "_id": "123", + "title 1": "content 1", + "desc 2": "content 2. blah blah blah" + }], + auto_refresh=True + ) + ) count1_res = requests.post( F"{self.endpoint}/{self.index_name_1}/_count", timeout=self.config.timeout, @@ -150,7 +171,9 @@ def test_implicit_create_index(self): ) assert r1.status_code == 404 add_doc_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + ) ) r2 = requests.get( url=f"{self.endpoint}/{self.index_name_1}", @@ -171,7 +194,10 @@ def test_default_index_settings(self): def test_default_index_settings_implicitly_created(self): add_doc_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + config=self.config, + add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + ) ) index_info = requests.get( url=f"{self.endpoint}/{self.index_name_1}", @@ -184,7 +210,9 @@ def test_default_index_settings_implicitly_created(self): def test_add_new_fields_on_the_fly(self): add_doc_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + ) ) cluster_ix_info = requests.get( url=f"{self.endpoint}/{self.index_name_1}", @@ -194,7 +222,9 @@ def test_add_new_fields_on_the_fly(self): assert "__vector_abc" in cluster_ix_info.json()[self.index_name_1]["mappings"]["properties"][TensorField.chunks]["properties"] assert "dimension" in cluster_ix_info.json()[self.index_name_1]["mappings"]["properties"][TensorField.chunks]["properties"]["__vector_abc"] add_doc_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"abc": "1234", "The title book 1": "hahehehe"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"abc": "1234", "The title book 1": "hahehehe"}], auto_refresh=True + ) ) cluster_ix_info_2 = requests.get( url=f"{self.endpoint}/{self.index_name_1}", @@ -213,7 +243,9 @@ def test_add_new_fields_on_the_fly_index_cache_syncs(self): verify=False ) add_doc_res_1 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + ) ) index_info_2 = requests.get( url=f"{self.endpoint}/{self.index_name_1}", @@ -222,7 +254,9 @@ def test_add_new_fields_on_the_fly_index_cache_syncs(self): assert index_meta_cache.get_cache()[self.index_name_1].properties[TensorField.chunks]["properties"]["__vector_abc"] \ == index_info_2.json()[self.index_name_1]["mappings"]["properties"][TensorField.chunks]["properties"]["__vector_abc"] add_doc_res_2 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"cool field": "yep yep", "haha": "heheh"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"cool field": "yep yep", "haha": "heheh"}], auto_refresh=True + ) ) 
index_info_3 = requests.get( url=f"{self.endpoint}/{self.index_name_1}", @@ -234,7 +268,10 @@ def test_add_new_fields_on_the_fly_index_cache_syncs(self): def test_add_multiple_fields(self): add_doc_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"cool v field": "yep yep", "haha ee": "heheh"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"cool v field": "yep yep", "haha ee": "heheh"}], + auto_refresh=True + ) ) cluster_ix_info = requests.get( url=f"{self.endpoint}/{self.index_name_1}", @@ -249,25 +286,28 @@ def test_add_multiple_fields(self): def test_add_docs_response_format(self): tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) - add_res = tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "_id": "123", - "id": "abcdefgh", - "title 1": "content 1", - "desc 2": "content 2. blah blah blah" - }, - { - "_id": "456", - "id": "abcdefgh", - "title 1": "content 1", - "desc 2": "content 2. blah blah blah" - }, - { - "_id": "789", - "subtitle": [1, 2, 3] - } - ], auto_refresh=True) + + add_res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "_id": "123", + "id": "abcdefgh", + "title 1": "content 1", + "desc 2": "content 2. blah blah blah" + }, + { + "_id": "456", + "id": "abcdefgh", + "title 1": "content 1", + "desc 2": "content 2. blah blah blah" + }, + { + "_id": "789", + "subtitle": [1, 2, 3] + } + ], auto_refresh=True) + ) assert "errors" in add_res assert "processingTimeMs" in add_res assert "index_name" in add_res @@ -308,8 +348,11 @@ def test_add_documents_validation(self): # For update for bad_doc_arg in bad_doc_args: add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=bad_doc_arg, auto_refresh=True, update_mode='update') + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=bad_doc_arg, auto_refresh=True, update_mode='update' + ) + ) assert add_res['errors'] is True assert all(['error' in item for item in add_res['items'] if item['_id'].startswith('to_fail')]) assert all(['result' in item @@ -319,8 +362,11 @@ def test_add_documents_validation(self): for use_existing_tensors_flag in (True, False): for bad_doc_arg in bad_doc_args: add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=bad_doc_arg, auto_refresh=True, update_mode='replace', use_existing_tensors=use_existing_tensors_flag) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=bad_doc_arg, auto_refresh=True, + update_mode='replace', use_existing_tensors=use_existing_tensors_flag + ) + ) assert add_res['errors'] is True assert all(['error' in item for item in add_res['items'] if item['_id'].startswith('to_fail')]) assert all(['result' in item @@ -346,9 +392,11 @@ def test_add_documents_id_validation(self): # For update for bad_doc_arg in bad_doc_args: add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=bad_doc_arg[0], auto_refresh=True, update_mode='update') - + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=bad_doc_arg[0], + auto_refresh=True, update_mode='update' + ) + ) assert add_res['errors'] is True succeeded_count = 0 @@ -362,8 +410,11 @@ def test_add_documents_id_validation(self): for use_existing_tensors_flag in 
(True, False): for bad_doc_arg in bad_doc_args: add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=bad_doc_arg[0], auto_refresh=True, update_mode='replace', use_existing_tensors=use_existing_tensors_flag) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=bad_doc_arg[0], auto_refresh=True, + update_mode='replace', use_existing_tensors=use_existing_tensors_flag + ) + ) assert add_res['errors'] is True succeeded_count = 0 for item in add_res['items']: @@ -380,8 +431,11 @@ def test_add_documents_list_non_tensor_validation(self): for update_mode in ('replace', 'update'): for bad_doc_arg in bad_doc_args: add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=bad_doc_arg, auto_refresh=True, update_mode=update_mode) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=bad_doc_arg, + auto_refresh=True, update_mode=update_mode + ) + ) assert add_res['errors'] is True assert all(['error' in item for item in add_res['items'] if item['_id'].startswith('to_fail')]) @@ -392,9 +446,12 @@ def test_add_documents_list_success(self): for update_mode in ('replace', 'update'): for bad_doc_arg in good_docs: add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=bad_doc_arg, auto_refresh=True, update_mode=update_mode, - non_tensor_fields=["my_field"]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=bad_doc_arg, auto_refresh=True, update_mode=update_mode, + non_tensor_fields=["my_field"] + ) + ) assert add_res['errors'] is False def test_add_documents_list_data_type_validation(self): @@ -407,9 +464,12 @@ def test_add_documents_list_data_type_validation(self): for update_mode in ('replace', 'update'): for bad_doc_arg in bad_doc_args: add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=bad_doc_arg, auto_refresh=True, update_mode=update_mode, - non_tensor_fields=["my_field"]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=bad_doc_arg, auto_refresh=True, update_mode=update_mode, + non_tensor_fields=["my_field"] + ) + ) assert add_res['errors'] is True assert all(['error' in item for item in add_res['items'] if item['_id'].startswith('to_fail')]) @@ -425,8 +485,11 @@ def test_add_documents_set_device(self): @mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) def run(): tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, device="cuda:411", docs=[{"some": "doc"}], - auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, device="cuda:411", docs=[{"some": "doc"}], + auto_refresh=True + ) + ) return True assert run() @@ -445,8 +508,12 @@ def test_add_documents_orchestrator_set_device_single_process(self): @mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) def run(): tensor_search.add_documents_orchestrator( - config=self.config, index_name=self.index_name_1, device="cuda:22", docs=[{"some": "doc"}, {"som other": "doc"}], - auto_refresh=True, batch_size=1, processes=1) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, device="cuda:22", docs=[{"some": "doc"}, {"som other": "doc"}], + auto_refresh=True, + ), + batch_size=1, processes=1 + ) return True assert run() @@ -465,8 +532,12 @@ def 
test_add_documents_orchestrator_set_device_empty_batch(self): @mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) def run(): tensor_search.add_documents_orchestrator( - config=self.config, index_name=self.index_name_1, device="cuda:22", docs=[{"some": "doc"}, {"som other": "doc"}], - auto_refresh=True, batch_size=0) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, device="cuda:22", docs=[{"some": "doc"}, {"som other": "doc"}], + auto_refresh=True + ), + batch_size=0 + ) return True assert run() @@ -477,8 +548,10 @@ def run(): def test_add_documents_empty(self): try: tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[], - auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[], + auto_refresh=True) + ) raise AssertionError except BadRequestError: pass @@ -529,7 +602,8 @@ def test_resilient_add_images(self): ), ] for docs, expected_results in docs_results: - add_res = tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True) + add_res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs, auto_refresh=True)) assert len(add_res['items']) == len(expected_results) for i, res_dict in enumerate(add_res['items']): assert res_dict["_id"] == expected_results[i][0] @@ -579,8 +653,10 @@ def test_add_documents_resilient_doc_validation(self): for update_mode in ('update', 'replace'): for docs, expected_results in docs_results: add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True, - update_mode=update_mode + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs, auto_refresh=True, + update_mode=update_mode + ) ) assert len(add_res['items']) == len(expected_results) for i, res_dict in enumerate(add_res['items']): @@ -633,8 +709,10 @@ def test_mappings_arent_updated(self): for docs, (good_fields, bad_fields) in docs_results: # good_fields should appear in the mapping. # bad_fields should not - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, - docs=docs, auto_refresh=True, update_mode=update_mode) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs, auto_refresh=True, update_mode=update_mode + ) + ) ii = backend.get_index_info(config=self.config, index_name=self.index_name_1) customer_props = {field_name for field_name in ii.get_text_properties()} reduced_vector_props = {field_name.replace(TensorField.vector_prefix, '') @@ -690,7 +768,9 @@ def test_mappings_arent_updated_images(self): for docs, (good_fields, bad_fields) in docs_results: # good_fields should appear in the mapping. 
# bad_fields should not - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs, auto_refresh=True) + ) ii = backend.get_index_info(config=self.config, index_name=self.index_name_1) customer_props = {field_name for field_name in ii.get_text_properties()} reduced_vector_props = {field_name.replace(TensorField.vector_prefix, '') @@ -705,11 +785,14 @@ def test_mappings_arent_updated_images(self): tensor_search.delete_index(config=self.config, index_name=self.index_name_1) def patch_documents_tests(self, docs_, update_docs, get_docs): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs_, auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs_, auto_refresh=True + )) update_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=update_docs, - auto_refresh=True, update_mode='update') - + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=update_docs, + auto_refresh=True, update_mode='update' + )) for doc_id, check_dict in get_docs.items(): updated_doc = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id=doc_id @@ -844,7 +927,8 @@ def test_put_documents_no_outdated_chunks(self): 3) there are no dangling chunks """ docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs_, auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs_, auto_refresh=True)) original_doc = requests.get( url=F"{self.endpoint}/{self.index_name_1}/_doc/789", verify=False @@ -853,10 +937,13 @@ def test_put_documents_no_outdated_chunks(self): description_chunk = [chunk for chunk in original_doc.json()['_source']['__chunks'] if chunk['__field_name'] == 'Description'][0] update_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[ - {"_id": "789", "Title": "Story of Alice Appleseed", - "Description": "Alice grew up in Rooster, Texas."}], - auto_refresh=True, update_mode='update') + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{ + "_id": "789", "Title": "Story of Alice Appleseed", + "Description": "Alice grew up in Rooster, Texas."}], + auto_refresh=True, update_mode='update' + ) + ) updated_doc = requests.get( url=F"{self.endpoint}/{self.index_name_1}/_doc/789", verify=False @@ -878,13 +965,22 @@ def test_put_documents_search(self): """Can we search with the new vectors """ docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs_, auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs_, auto_refresh=True)) search_str = "Who is an alien?" 
first_search = tensor_search.search(config=self.config, index_name=self.index_name_1, text=search_str) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"_id": "789", "Title": "Story of Alice Appleseed", - "Description": "Unbeknownst to most, Alice is actually an alien in disguise. She uses a UFO to commute to work."} - ], auto_refresh=True, update_mode='update') + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=[ + { + "_id": "789", "Title": "Story of Alice Appleseed", + "Description": "Unbeknownst to most, Alice is actually an alien in disguise. " + "She uses a UFO to commute to work." + } + ], + auto_refresh=True, update_mode='update' + ) + ) second_search = tensor_search.search(config=self.config, index_name=self.index_name_1, text=search_str) assert not np.isclose(first_search["hits"][0]["_score"], second_search["hits"][0]["_score"]) assert second_search["hits"][0]["_score"] > first_search["hits"][0]["_score"] @@ -893,10 +989,15 @@ def test_put_documents_search_new_fields(self): """Can we search with the new field? """ docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs_, auto_refresh=True) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"_id": "789", "Title": "Story of Alice Appleseed", "Favourite Wavelength": "2 microns"} - ], auto_refresh=True, update_mode='update') + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs_, auto_refresh=True)) + tensor_search.add_documents(config=self.config, + add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"_id": "789", "Title": "Story of Alice Appleseed", "Favourite Wavelength": "2 microns"}], + auto_refresh=True, update_mode='update' + ) + ) searched = tensor_search.search( config=self.config, index_name=self.index_name_1, text="A very small length", searchable_attributes=['Favourite Wavelength'] @@ -906,8 +1007,11 @@ def test_put_documents_search_new_fields(self): def patch_documents_filtering_test(self, original_add_docs, update_add_docs, filter_string, expected_ids: set): """Helper for filtering tests""" - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=original_add_docs, auto_refresh=True) - res = tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=update_add_docs, auto_refresh=True, update_mode='update') + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=original_add_docs, auto_refresh=True)) + res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=update_add_docs, auto_refresh=True, update_mode='update' + )) abc = requests.get( url=F"{self.endpoint}/{self.index_name_1}/_doc/789", @@ -970,8 +1074,10 @@ def test_put_documents_filtering_int(self): def test_put_document_override_non_tensor_field(self): docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs_, auto_refresh=True, non_tensor_fields=["Title"]) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs_, auto_refresh=True) 
+ tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs_, auto_refresh=True, non_tensor_fields=["Title"])) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs_, auto_refresh=True)) resp = tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="789", show_vectors=True) assert len(resp[enums.TensorField.tensor_facets]) == 2 @@ -983,7 +1089,9 @@ def test_put_document_override_non_tensor_field(self): def test_add_document_with_non_tensor_field(self): docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs_, auto_refresh=True, non_tensor_fields=["Title"]) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs_, auto_refresh=True, non_tensor_fields=["Title"] + )) resp = tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="789", show_vectors=True) assert len(resp[enums.TensorField.tensor_facets]) == 1 @@ -992,10 +1100,10 @@ def test_add_document_with_non_tensor_field(self): assert "Description" in resp[enums.TensorField.tensor_facets][0] def test_put_no_update(self): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[{'_id':'123'}], - auto_refresh=True, update_mode='replace') - res = tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[{'_id':'123'}], - auto_refresh=True, update_mode='replace') + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{'_id':'123'}], auto_refresh=True, update_mode='replace')) + res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{'_id':'123'}], auto_refresh=True, update_mode='replace')) assert {'_id':'123'} == tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id='123') @@ -1017,45 +1125,6 @@ def test_put_no_update_existing_field_float(self): config=self.config, index_name=self.index_name_1, document_id='123') assert {'_id': '123', "the_float": 20.22} == get_res - def test_put_documents_orchestrator(self): - """ - """ - docs_ = [ - {"_id": "123", "Title": "Story of Joe Blogs", "Description": "Joe was a great farmer."}, - {"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."} - ] - - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=docs_, auto_refresh=True) - update_res = tensor_search.add_documents_orchestrator( - config=self.config, index_name=self.index_name_1, docs=[ - {"_id": "789", "Title": "Woohoo", "Mega": "Coool"}, - {"_id": "789", "Luminosity": "Extreme"}, - {"_id": "789", "Temp": 12.5}, - ], - auto_refresh=True, update_mode='update', processes=4, batch_size=1) - time.sleep(5) - updated_doc = tensor_search.get_document_by_id( - config=self.config, index_name=self.index_name_1, document_id='789' - ) - check_dict = {"_id": '789', "Temp": 12.5, "Luminosity": "Extreme", "Title": "Woohoo", "Mega": "Coool"} - for field, expected_value in check_dict.items(): - assert updated_doc[field] == expected_value - - updated_raw_doc = requests.get( - url=F"{self.endpoint}/{self.index_name_1}/_doc/789", - verify=False - ) - 
check_dict_no_id = copy.deepcopy(check_dict) - try: - del check_dict_no_id['_id'] - except KeyError: - pass - # make sure that the chunks have been updated - for ch in updated_raw_doc.json()['_source']['__chunks']: - assert '_id' not in ch - for field, expected_value in check_dict_no_id.items(): - assert ch[field] == expected_value - def test_doc_too_large(self): max_size = 400000 mock_environ = {enums.EnvVars.MARQO_MAX_DOC_BYTES: str(max_size)} @@ -1063,12 +1132,14 @@ def test_doc_too_large(self): @mock.patch("os.environ", mock_environ) def run(): update_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[ + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ {"_id": "123", 'Bad field': "edf " * (max_size // 4)}, {"_id": "789", "Breaker": "abc " * ((max_size // 4) - 500)}, {"_id": "456", "Luminosity": "exc " * (max_size // 4)}, ], - auto_refresh=True, update_mode='update') + auto_refresh=True, update_mode='update' + )) items = update_res['items'] assert update_res['errors'] assert 'error' in items[0] and 'error' in items[2] @@ -1085,10 +1156,12 @@ def test_doc_too_large_single_doc(self): @mock.patch("os.environ", mock_environ) def run(): update_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[ + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ {"_id": "123", 'Bad field': "edf " * (max_size // 4)}, ], - auto_refresh=True, update_mode='update') + auto_refresh=True, update_mode='update') + ) items = update_res['items'] assert update_res['errors'] assert 'error' in items[0] @@ -1101,10 +1174,12 @@ def test_doc_too_large_none_env_var(self): @mock.patch("os.environ", env_dict) def run(): update_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[ + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ {"_id": "123", 'Some field': "Some content"}, - ], - auto_refresh=True, update_mode='update') + ], + auto_refresh=True, update_mode='update' + )) items = update_res['items'] assert not update_res['errors'] assert 'error' not in items[0] @@ -1116,9 +1191,10 @@ def test_non_tensor_field_list(self): test_doc = {"_id": "123", "my_list": ["data1", "mydata"], "myfield2": "mydata2"} tensor_search.add_documents( self.config, - docs=[test_doc], - auto_refresh=True, index_name=self.index_name_1, non_tensor_fields=['my_list'] - ) + add_docs_params=AddDocsParams( + docs=[test_doc], auto_refresh=True, + index_name=self.index_name_1, non_tensor_fields=['my_list'] + )) doc_w_facets = tensor_search.get_document_by_id( self.config, index_name=self.index_name_1, document_id='123', show_vectors=True) @@ -1158,15 +1234,18 @@ def test_non_tensor_field_list(self): def test_no_tensor_field_replace(self): # test replace and update workflows tensor_search.add_documents( - self.config, - docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata2"}], - auto_refresh=True, index_name=self.index_name_1 + self.config, add_docs_params=AddDocsParams( + docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata2"}], + auto_refresh=True, index_name=self.index_name_1 + ) ) tensor_search.add_documents( self.config, - docs=[{"_id": "123", "myfield": "mydata"}], - auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"] + add_docs_params=AddDocsParams( + docs=[{"_id": "123", "myfield": "mydata"}], + auto_refresh=True, index_name=self.index_name_1, + 
non_tensor_fields=["myfield"] + ) ) doc_w_facets = tensor_search.get_document_by_id( self.config, index_name=self.index_name_1, document_id='123', show_vectors=True) @@ -1176,15 +1255,17 @@ def test_no_tensor_field_replace(self): def test_no_tensor_field_update(self): # test replace and update workflows tensor_search.add_documents( - self.config, - docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata2"}], - auto_refresh=True, index_name=self.index_name_1 + self.config, add_docs_params=AddDocsParams( + docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata2"}], + auto_refresh=True, index_name=self.index_name_1 + ) ) tensor_search.add_documents( - self.config, - docs=[{"_id": "123", "myfield": "mydata"}], - auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"], update_mode='update' + self.config, add_docs_params=AddDocsParams( + docs=[{"_id": "123", "myfield": "mydata"}], + auto_refresh=True, index_name=self.index_name_1, + non_tensor_fields=["myfield"], update_mode='update' + ) ) doc_w_facets = tensor_search.get_document_by_id( self.config, index_name=self.index_name_1, document_id='123', show_vectors=True) @@ -1195,10 +1276,11 @@ def test_no_tensor_field_update(self): def test_no_tensor_field_on_empty_ix(self): tensor_search.add_documents( - self.config, - docs=[{"_id": "123", "myfield": "mydata"}], - auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"] + self.config, add_docs_params=AddDocsParams( + docs=[{"_id": "123", "myfield": "mydata"}], + auto_refresh=True, index_name=self.index_name_1, + non_tensor_fields=["myfield"] + ) ) doc_w_facets = tensor_search.get_document_by_id( self.config, index_name=self.index_name_1, document_id='123', show_vectors=True) @@ -1207,10 +1289,11 @@ def test_no_tensor_field_on_empty_ix(self): def test_no_tensor_field_on_empty_ix_other_field(self): tensor_search.add_documents( - self.config, - docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata"}], - auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"] + self.config, add_docs_params=AddDocsParams( + docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata"}], + auto_refresh=True, index_name=self.index_name_1, + non_tensor_fields=["myfield"] + ) ) doc_w_facets = tensor_search.get_document_by_id( self.config, index_name=self.index_name_1, document_id='123', show_vectors=True) @@ -1263,11 +1346,13 @@ def _check_get_docs(doc_count, some_field_value): ) res1 = tensor_search.add_documents( self.config, - docs=[{"_id": str(doc_num), - "location": hippo_url, - "some_field": "blah"} for doc_num in range(c)], - auto_refresh=True, index_name=self.index_name_1, - update_mode=update_mode + add_docs_params=AddDocsParams( + docs=[{"_id": str(doc_num), + "location": hippo_url, + "some_field": "blah"} for doc_num in range(c)], + auto_refresh=True, index_name=self.index_name_1, + update_mode=update_mode + ) ) assert c == tensor_search.get_stats(self.config, index_name=self.index_name_1)['numberOfDocuments'] @@ -1275,11 +1360,13 @@ def _check_get_docs(doc_count, some_field_value): assert _check_get_docs(doc_count=c, some_field_value='blah') res2 = tensor_search.add_documents( self.config, - docs=[{"_id": str(doc_num), - "location": hippo_url, - "some_field": "blah2"} for doc_num in range(c)], - auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"], update_mode=update_mode + add_docs_params=AddDocsParams( + docs=[{"_id": str(doc_num), + "location": hippo_url, + "some_field": "blah2"} for 
doc_num in range(c)], + auto_refresh=True, index_name=self.index_name_1, + non_tensor_fields=["myfield"], update_mode=update_mode + ) ) assert not res2['errors'] assert c == tensor_search.get_stats(self.config, @@ -1323,23 +1410,26 @@ def _check_get_docs(doc_count, some_field_value): ) res1 = tensor_search.add_documents( self.config, - docs=[{"_id": str(doc_num), - "location": hippo_url, - "some_field": "blah"} for doc_num in range(c)], - auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["location"], update_mode=update_mode - ) + add_docs_params=AddDocsParams( + docs=[{"_id": str(doc_num), + "location": hippo_url, + "some_field": "blah"} for doc_num in range(c)], + auto_refresh=True, index_name=self.index_name_1, + non_tensor_fields=["location"], update_mode=update_mode + )) assert c == tensor_search.get_stats(self.config, index_name=self.index_name_1)['numberOfDocuments'] assert not res1['errors'] assert _check_get_docs(doc_count=c, some_field_value='blah') res2 = tensor_search.add_documents( self.config, - docs=[{"_id": str(doc_num), - "location": hippo_url, - "some_field": "blah2"} for doc_num in range(c)], - auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["location"], update_mode=update_mode + add_docs_params=AddDocsParams( + docs=[{"_id": str(doc_num), + "location": hippo_url, + "some_field": "blah2"} for doc_num in range(c)], + auto_refresh=True, index_name=self.index_name_1, + non_tensor_fields=["location"], update_mode=update_mode + ) ) assert not res2['errors'] assert c == tensor_search.get_stats(self.config, diff --git a/tests/tensor_search/test_add_documents_use_existing_tensors.py b/tests/tensor_search/test_add_documents_use_existing_tensors.py index 5e82aef11..631691add 100644 --- a/tests/tensor_search/test_add_documents_use_existing_tensors.py +++ b/tests/tensor_search/test_add_documents_use_existing_tensors.py @@ -1,4 +1,4 @@ -import pprint +from marqo.tensor_search.models.add_docs_objects import AddDocsParams import unittest.mock import requests from tests.marqo_test import MarqoTestCase @@ -29,14 +29,14 @@ def test_use_existing_tensors_resilience(self): } # 1 valid ID doc: res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[d1, {'_id': 1224}, {"_id": "fork", "abc": "123"}], - auto_refresh=True, use_existing_tensors=True) + config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[d1, {'_id': 1224}, {"_id": "fork", "abc": "123"}], + auto_refresh=True, use_existing_tensors=True)) assert [item['status'] for item in res['items']] == [201, 400, 201] # no valid IDs res_no_valid_id = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[d1, {'_id': 1224}, d1], - auto_refresh=True, use_existing_tensors=True) + config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[d1, {'_id': 1224}, d1], + auto_refresh=True, use_existing_tensors=True)) # we also should not be send in a get request as there are no valid document IDs assert [item['status'] for item in res_no_valid_id['items']] == [201, 400, 201] @@ -48,11 +48,11 @@ def test_use_existing_tensors_no_id(self): "desc 2": "content 2. 
blah blah blah" } r1 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[d1], - auto_refresh=True, use_existing_tensors=True) + config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[d1], + auto_refresh=True, use_existing_tensors=True)) r2 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[d1, d1], - auto_refresh=True, use_existing_tensors=True) + config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[d1, d1], + auto_refresh=True, use_existing_tensors=True)) for item in r1['items']: assert item['result'] == 'created' @@ -64,12 +64,14 @@ def test_use_existing_tensors_non_existing(self): """check parity between a doc created with and without use_existing_tensors, then overwritten, for a newly created doc. """ - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", "desc 2": "content 2. blah blah blah" - }], auto_refresh=True, use_existing_tensors=False) + }], auto_refresh=True, use_existing_tensors=False)) + regular_doc = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -77,23 +79,25 @@ def test_use_existing_tensors_non_existing(self): tensor_search.delete_index(config=self.config, index_name=self.index_name_1) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", "desc 2": "content 2. blah blah blah" - }], auto_refresh=True, use_existing_tensors=True) + }], auto_refresh=True, use_existing_tensors=True)) use_existing_tensors_doc = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) self.assertEqual(use_existing_tensors_doc, regular_doc) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", "desc 2": "content 2. blah blah blah" - }], auto_refresh=True, use_existing_tensors=True) + }], auto_refresh=True, use_existing_tensors=True)) overwritten_doc = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) @@ -105,38 +109,39 @@ def test_use_existing_tensors_dupe_ids(self): Should only use the latest inserted ID. 
Make sure it doesn't get the first/middle one """ - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "3", "title": "doc 3b" }, - ], auto_refresh=True) + ], auto_refresh=True)) doc_3_solo = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="3", show_vectors=True) tensor_search.delete_index(config=self.config, index_name=self.index_name_1) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "_id": "1", - "title": "doc 1" - }, - { - "_id": "2", - "title": "doc 2", - }, - { - "_id": "3", - "title": "doc 3a", - }, - { - "_id": "3", - "title": "doc 3b" - }, - - ], auto_refresh=True, use_existing_tensors=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "_id": "1", + "title": "doc 1" + }, + { + "_id": "2", + "title": "doc 2", + }, + { + "_id": "3", + "title": "doc 3a", + }, + { + "_id": "3", + "title": "doc 3b" + }], + auto_refresh=True, use_existing_tensors=True)) doc_3_duped = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -144,7 +149,8 @@ def test_use_existing_tensors_dupe_ids(self): self.assertEqual(doc_3_solo, doc_3_duped) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "1", "title": "doc 1" @@ -162,7 +168,7 @@ def test_use_existing_tensors_dupe_ids(self): "title": "doc 3b" }, - ], auto_refresh=True, use_existing_tensors=True) + ], auto_refresh=True, use_existing_tensors=True)) doc_3_overwritten = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -178,7 +184,8 @@ def test_use_existing_tensors_retensorize_fields(self): They should still have no tensors. 
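The use_existing_tensors tests in this file all exercise the same per-field decision, spelled out in their inline comments: reuse the stored vectors when a field's content is unchanged from the indexed copy, re-vectorise when it is new or modified, and never vectorise fields listed in non_tensor_fields. The sketch below is a hedged, standalone illustration of that decision only; `fields_to_vectorise` and its inputs are hypothetical and do not reflect Marqo's actual internals.

# Hedged sketch of the per-field reuse decision described by these tests.
# Purely illustrative; not Marqo's implementation.
from typing import Dict, List

def fields_to_vectorise(new_doc: Dict, existing_doc: Dict, non_tensor_fields: List[str]) -> List[str]:
    """Return the fields that need fresh vectors when use_existing_tensors is on."""
    to_vectorise = []
    for field_name, content in new_doc.items():
        if field_name == "_id" or field_name in non_tensor_fields:
            continue                                    # explicitly non-tensorised
        if existing_doc.get(field_name) == content:
            continue                                    # unchanged -> reuse whatever is stored
        to_vectorise.append(field_name)                 # new or modified -> re-vectorise
    return to_vectorise

existing = {"_id": "123", "title 1": "content 1", "modded field": "original content"}
incoming = {"_id": "123", "title 1": "content 1", "modded field": "updated content",
            "2nd-non-tensor-field": "content 2"}
assert fields_to_vectorise(incoming, existing, ["2nd-non-tensor-field"]) == ["modded field"]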
""" - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", @@ -186,20 +193,21 @@ def test_use_existing_tensors_retensorize_fields(self): "title 3": True, "title 4": "content 4" }], auto_refresh=True, use_existing_tensors=True, - non_tensor_fields=["title 1", "title 2", "title 3", "title 4"]) + non_tensor_fields=["title 1", "title 2", "title 3", "title 4"])) d1 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) assert len(d1["_tensor_facets"]) == 0 - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", "title 2": 2, "title 3": True, "title 4": "content 4" - }], auto_refresh=True, use_existing_tensors=True) + }], auto_refresh=True, use_existing_tensors=True)) d2 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) @@ -212,45 +220,49 @@ def test_use_existing_tensors_getting_non_tensorised(self): When we insert the doc again, with use_existing_tensors, because the content hasn't changed, we use the existing (non-existent) vectors """ - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", "non-tensor-field": "content 2. blah blah blah" - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"]) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) d1 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) assert len(d1["_tensor_facets"]) == 1 assert "title 1" in d1["_tensor_facets"][0] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", "non-tensor-field": "content 2. blah blah blah" - }], auto_refresh=True, use_existing_tensors=True) + }], auto_refresh=True, use_existing_tensors=True)) d2 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) self.assertEqual(d1["_tensor_facets"], d2["_tensor_facets"]) # The only field is a non-tensor field. This makes a chunkless doc. - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "999", "non-tensor-field": "content 2. 
blah blah blah" - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"]) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) d1 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="999", show_vectors=True) assert len(d1["_tensor_facets"]) == 0 - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "999", "non-tensor-field": "content 2. blah blah blah" - }], auto_refresh=True, use_existing_tensors=True) + }], auto_refresh=True, use_existing_tensors=True)) d2 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="999", show_vectors=True) @@ -259,13 +271,14 @@ def test_use_existing_tensors_getting_non_tensorised(self): def test_use_existing_tensors_check_updates(self): """ Check to see if the document has been appropriately updated """ - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", "modded field": "original content", "non-tensor-field": "content 2. blah blah blah" - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"]) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) def pass_through_vectorise(*arg, **kwargs): """Vectorise will behave as usual, but we will be able to see the call list @@ -277,7 +290,8 @@ def pass_through_vectorise(*arg, **kwargs): mock_vectorise.side_effect = pass_through_vectorise @unittest.mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) def run(): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", # this one should keep the same vectors @@ -285,7 +299,7 @@ def run(): "modded field": "updated content", # new vectors because the content is modified "non-tensor-field": "content 2. blah blah blah", # this would should still have no vectors "2nd-non-tensor-field": "content 2. 
blah blah blah" # this one is explicitly being non-tensorised - }], auto_refresh=True, non_tensor_fields=["2nd-non-tensor-field"], use_existing_tensors=True) + }], auto_refresh=True, non_tensor_fields=["2nd-non-tensor-field"], use_existing_tensors=True)) content_to_be_vectorised = [call_kwargs['content'] for call_args, call_kwargs in mock_vectorise.call_args_list] assert content_to_be_vectorised == [["cat on mat"], ["updated content"]] @@ -298,7 +312,8 @@ def test_use_existing_tensors_check_meta_data(self): Checks chunk meta data and vectors are as expected """ - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", @@ -307,7 +322,7 @@ def test_use_existing_tensors_check_meta_data(self): "field_that_will_disappear": "some stuff", # this gets dropped during the next add docs call, "field_to_be_list": "some stuff", "fl": 1.51 - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"]) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) use_existing_tensor_doc = { "title 1": "content 1", # this one should keep the same vectors @@ -322,9 +337,10 @@ def test_use_existing_tensors_check_meta_data(self): "new_bool": False } tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"_id": "123", **use_existing_tensor_doc}], + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"_id": "123", **use_existing_tensor_doc}], auto_refresh=True, non_tensor_fields=["2nd-non-tensor-field", "field_to_be_list", 'new_field_list'], - use_existing_tensors=True) + use_existing_tensors=True)) updated_doc = requests.get( url=F"{self.endpoint}/{self.index_name_1}/_doc/123", @@ -348,7 +364,8 @@ def test_use_existing_tensors_check_meta_data(self): assert found_vector_field def test_use_existing_tensors_check_meta_data_mappings(self): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "title 1": "content 1", @@ -357,7 +374,7 @@ def test_use_existing_tensors_check_meta_data_mappings(self): "field_that_will_disappear": "some stuff", # this gets dropped during the next add docs call "field_to_be_list": "some stuff", "fl": 1.51 - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"]) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) use_existing_tensor_doc = { "title 1": "content 1", # this one should keep the same vectors @@ -372,9 +389,9 @@ def test_use_existing_tensors_check_meta_data_mappings(self): "new_bool": False } tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"_id": "123", **use_existing_tensor_doc}], + config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{"_id": "123", **use_existing_tensor_doc}], auto_refresh=True, non_tensor_fields=["2nd-non-tensor-field", "field_to_be_list", 'new_field_list'], - use_existing_tensors=True) + use_existing_tensors=True)) tensor_search.index_meta_cache.refresh_index(config=self.config, index_name=self.index_name_1) @@ -410,7 +427,8 @@ def test_use_existing_tensors_long_strings_and_images(self): index_name=self.index_name_1, index_settings=index_settings, config=self.config) hippo_img = 
'https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png' artefact_hippo_img = 'https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_statue.png' - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "123", "txt_to_be_the_same": "some text to leave unchanged. I repeat, unchanged", @@ -422,7 +440,7 @@ def test_use_existing_tensors_long_strings_and_images(self): "fl": 1.23, "non-tensor-field": ["what", "is", "the", "time"] - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"]) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) def pass_through_vectorise(*arg, **kwargs): """Vectorise will behave as usual, but we will be able to see the call list @@ -445,9 +463,9 @@ def run(): "non-tensor-field": ["it", "is", "9", "o clock"] } tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"_id": "123", **use_existing_tensor_doc}], + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[{"_id": "123", **use_existing_tensor_doc}], auto_refresh=True, non_tensor_fields=["non-tensor-field"], - use_existing_tensors=True) + use_existing_tensors=True)) vectorised_content = [call_kwargs['content'] for call_args, call_kwargs in mock_vectorise.call_args_list] @@ -523,8 +541,8 @@ def test_use_existing_tensors_all_data_types(self): for doc_arg in doc_args: # Add doc normally without use_existing_tensors add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=doc_arg, auto_refresh=True, update_mode='replace') + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=doc_arg, auto_refresh=True, update_mode='replace')) d1 = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_1, @@ -532,8 +550,8 @@ def test_use_existing_tensors_all_data_types(self): # Then replace doc with use_existing_tensors add_res = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=doc_arg, auto_refresh=True, update_mode='replace', use_existing_tensors=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=doc_arg, auto_refresh=True, update_mode='replace', use_existing_tensors=True)) d2 = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_1, diff --git a/tests/tensor_search/test_api_utils.py b/tests/tensor_search/test_api_utils.py index 451fc99b1..920834648 100644 --- a/tests/tensor_search/test_api_utils.py +++ b/tests/tensor_search/test_api_utils.py @@ -1,6 +1,7 @@ -import requests -from marqo.tensor_search import enums, backend -from marqo.tensor_search import tensor_search +import pydantic +from marqo.tensor_search.models.add_docs_objects import ModelAuth +from marqo.tensor_search.models.private_models import S3Auth +import urllib.parse from marqo.tensor_search.web import api_utils from marqo.errors import InvalidArgError, InternalError from tests.marqo_test import MarqoTestCase @@ -43,4 +44,31 @@ def test_generate_config_bad_url(self): c = api_utils.upconstruct_authorized_url(opensearch_url=opensearch_url) raise AssertionError except InternalError: - pass \ No newline at end of file + pass + +class TestDecodeQueryStringModelAuth(MarqoTestCase): + + def test_decode_query_string_model_auth_none(self): + result = 
api_utils.decode_query_string_model_auth() + self.assertIsNone(result) + + def test_decode_query_string_model_auth_empty_string(self): + result = api_utils.decode_query_string_model_auth("") + self.assertIsNone(result) + + def test_decode_query_string_model_auth_valid(self): + model_auth_obj = ModelAuth(s3=S3Auth( + aws_access_key_id='some_acc_id', aws_secret_access_key='some_sece_key')) + model_auth_str = model_auth_obj.json() + model_auth_url_encoded = urllib.parse.quote_plus(model_auth_str) + + result = api_utils.decode_query_string_model_auth(model_auth_url_encoded) + + self.assertIsInstance(result, ModelAuth) + self.assertEqual(result.s3.aws_access_key_id, 'some_acc_id') + self.assertEqual(result.s3.aws_secret_access_key, 'some_sece_key') + self.assertEqual(result.hf, None) + + def test_decode_query_string_model_auth_invalid(self): + with self.assertRaises(pydantic.ValidationError): + api_utils.decode_query_string_model_auth("invalid_url_encoded_string") \ No newline at end of file diff --git a/tests/tensor_search/test_backend.py b/tests/tensor_search/test_backend.py index 57a26b6d4..f47165b04 100644 --- a/tests/tensor_search/test_backend.py +++ b/tests/tensor_search/test_backend.py @@ -1,7 +1,6 @@ import copy import json -import pprint - +from tests.utils.transition import add_docs_caller import requests from marqo.tensor_search import enums, backend, utils from marqo.tensor_search import tensor_search @@ -62,8 +61,8 @@ def test_add_customer_field_properties_defaults_lucene(self): config=mock_config, index_name=self.index_name_1) @mock.patch("marqo._httprequests.HttpRequests.put", mock__put) def run(): - tensor_search.add_documents(config=mock_config, docs=[{"f1": "doc"}, {"f2":"C"}], - index_name=self.index_name_1, auto_refresh=True) + add_docs_caller(config=mock_config, docs=[{"f1": "doc"}, {"f2":"C"}], + index_name=self.index_name_1, auto_refresh=True) return True assert run() args, kwargs0 = mock__put.call_args_list[0] @@ -79,7 +78,7 @@ def test_add_customer_field_properties_default_ann_parameters(self): config=mock_config, index_name=self.index_name_1) @mock.patch("marqo._httprequests.HttpRequests.put", mock__put) def run(): - tensor_search.add_documents(config=mock_config, docs=[{"f1": "doc"}, {"f2":"C"}], + add_docs_caller(config=mock_config, docs=[{"f1": "doc"}, {"f2":"C"}], index_name=self.index_name_1, auto_refresh=True) return True assert run() @@ -108,7 +107,7 @@ def test_add_customer_field_properties_index_ann_parameters(self): ) @mock.patch("marqo._httprequests.HttpRequests.put", mock__put) def run(): - tensor_search.add_documents(config=mock_config, docs=[{"f1": "doc"}, {"f2":"C"}], + add_docs_caller(config=mock_config, docs=[{"f1": "doc"}, {"f2":"C"}], index_name=self.index_name_1, auto_refresh=True) return True assert run() diff --git a/tests/tensor_search/test_boost_field_scores.py b/tests/tensor_search/test_boost_field_scores.py index b1f810f45..ab1f70c3d 100644 --- a/tests/tensor_search/test_boost_field_scores.py +++ b/tests/tensor_search/test_boost_field_scores.py @@ -1,6 +1,6 @@ from marqo.errors import IndexNotFoundError, InvalidArgError from marqo.tensor_search import tensor_search - +from tests.utils.transition import add_docs_caller from tests.marqo_test import MarqoTestCase @@ -16,7 +16,7 @@ def setUp(self): tensor_search.create_vector_index( index_name=self.index_name_1, config=self.config) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { 
"Title": "The Travels of Marco Polo", "Description": "A 13th-century travelogue describing Polo's travels", @@ -100,7 +100,7 @@ def setUp(self): pass def test_boost_multiple_fields(self): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { "Title": "A comparison of the best pets", "Description": "Animals", @@ -133,7 +133,7 @@ def test_boost_multiple_fields(self): def test_boost_equation_single_field(self): # add a test to check if the score is boosted as expected - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { "Title": "A comparison of the best pets", "Description": "Animals", @@ -177,7 +177,7 @@ def test_boost_equation_single_field(self): def test_boost_equation_multiple_fields(self): # add a test to check if the score is boosted as expected - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { "Title": "A comparison of the best pets", "Description": "Animals", @@ -209,7 +209,7 @@ def test_boost_equation_multiple_fields(self): def test_boost_equation_with_multiple_docs(self): # add a test to check if the score is boosted as expected for num_of_doc in range(1,20): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { "Title": "A comparison of the best pets", "Description": "Animals", @@ -242,10 +242,10 @@ def test_boost_equation_with_multiple_docs(self): index_name=self.index_name_1, config=self.config) - def test_boost_equation_with_multiple_docs(self): + def test_boost_equation_with_multiple_docs_2(self): # add a test to check if the score is boosted as expected for num_of_doc in range(1,20): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { "Title": "A comparison of the best pets", "Description": "Animals", @@ -280,7 +280,7 @@ def test_boost_equation_with_multiple_docs(self): def test_boost_equation_with_pagination_docs(self): # add a test to check if the score is boosted as expected num_of_doc = 50 - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { "Title": "A comparison of the best pets", "Description": "Animals", @@ -322,7 +322,7 @@ def test_boost_equation_with_different_fields(self): boost[f"void_field_{i}"] = [1,1] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ docs ] * num_of_doc, auto_refresh=True) diff --git a/tests/tensor_search/test_bulk_search.py b/tests/tensor_search/test_bulk_search.py index d5d273d2b..062a62c3f 100644 --- a/tests/tensor_search/test_bulk_search.py +++ b/tests/tensor_search/test_bulk_search.py @@ -4,6 +4,7 @@ import requests import random from unittest import mock +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.s2_inference.s2_inference import vectorise import unittest from marqo.tensor_search.enums import TensorField, SearchMethod, EnvVars, IndexSettingsField @@ -13,11 +14,11 @@ from marqo.tensor_search import api from 
marqo.tensor_search.models.api_models import BulkSearchQuery, BulkSearchQueryEntity from marqo.tensor_search import tensor_search, constants, index_meta_cache, utils -from fastapi.exceptions import RequestValidationError import numpy as np from tests.marqo_test import MarqoTestCase from typing import List import pydantic +from tests.utils.transition import add_docs_caller def pass_through_vectorise(*arg, **kwargs): @@ -47,7 +48,7 @@ def _delete_test_indices(self, indices=None): pass def test_bulk_search_w_extra_parameters__raise_exception(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -62,7 +63,7 @@ def test_bulk_search_w_extra_parameters__raise_exception(self): @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '0'}}) def test_bulk_search_with_excessive_searchable_attributes(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -78,7 +79,7 @@ def test_bulk_search_with_excessive_searchable_attributes(self): @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '100'}}) def test_bulk_search_with_max_searchable_attributes_no_searchable_attributes_field(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -92,7 +93,7 @@ def test_bulk_search_with_max_searchable_attributes_no_searchable_attributes_fie @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '1'}}) def test_bulk_search_with_excessive_searchable_attributes(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -107,7 +108,7 @@ def test_bulk_search_with_excessive_searchable_attributes(self): @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': None}}) def test_bulk_search_with_no_max_searchable_attributes(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -121,7 +122,7 @@ def test_bulk_search_with_no_max_searchable_attributes(self): @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': None}}) def test_bulk_search_with_no_max_searchable_attributes_no_searchable_attributes_field(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -134,7 +135,7 @@ def test_bulk_search_with_no_max_searchable_attributes_no_searchable_attributes_ def 
test_bulk_search_no_queries_return_early(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -144,19 +145,19 @@ def test_bulk_search_no_queries_return_early(self): assert resp['result'] == [] def test_bulk_search_multiple_indexes_and_queries(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, ], auto_refresh=True ) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_2, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id2-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id2-second"}, ], auto_refresh=True ) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_3, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id3-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id3-second"}, @@ -205,7 +206,7 @@ def test_multimodal_tensor_combination_zero_weight(self): } }) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=documents, auto_refresh=True, mappings = {"combo_text_image" : {"type":"multimodal_combination", "weights":{"image_field": 0,"text_field": 1}}} ) @@ -219,7 +220,7 @@ def test_multimodal_tensor_combination_zero_weight(self): assert res['result'][0]["hits"][0]["_score"] == res['result'][0]["hits"][1]["_score"] def test_bulk_search_works_on_uncached_field(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe efgh ", "other field": "baaadd efgh ", "_id": "5678", "finally": "some field efgh "}, @@ -245,7 +246,7 @@ def test_bulk_search_works_on_uncached_field(self): assert self.index_name_3 in index_meta_cache.get_cache() def test_bulk_search_query_boosted(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ { "Title": "The Travels of Marco Polo", @@ -283,7 +284,7 @@ def test_bulk_search_query_boosted(self): self.assertGreater(score_boosted, score_neg_boosted) def test_bulk_search_query_invalid_boosted(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ { "Title": "The Travels of Marco Polo", @@ -310,12 +311,12 @@ def test_bulk_search_query_invalid_boosted(self): pass def test_bulk_search_multiple_indexes(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, ], auto_refresh=True) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_2, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id2-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id2-second"}, @@ -343,7 +344,7 @@ def test_bulk_search_multiple_indexes(self): @mock.patch("marqo.tensor_search.tensor_search.bulk_msearch") def 
test_bulk_search_multiple_queries_single_msearch_request(self, mock_bulk_msearch): tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -384,7 +385,8 @@ def run(): content=['one thing'], device='cpu', normalize_embeddings=False, - image_download_headers=None + image_download_headers=None, + model_auth=None ) print("ock_vectorise.call_args_listock_vectorise.call_args_list", mock_vectorise.call_args_list) @@ -395,7 +397,8 @@ def run(): content=['two things'], device='cpu', normalize_embeddings=True, - image_download_headers=None + image_download_headers=None, + model_auth=None ) self.assertEqual(mock_vectorise.call_count, 2) @@ -424,7 +427,8 @@ def run (): content=['one thing', 'two things'], device='cpu', normalize_embeddings=True, - image_download_headers=None + image_download_headers=None, + model_auth=None ) self.assertEqual(mock_vectorise.call_count, 1) return True @@ -432,7 +436,7 @@ def run (): def test_bulk_search_multiple_search_methods(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -456,7 +460,7 @@ def test_bulk_search_multiple_search_methods(self): assert len(tensor_result["hits"]) > 0 def test_bulk_search_highlight_per_search_query(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -481,7 +485,7 @@ def test_bulk_search_highlight_per_search_query(self): @mock.patch("marqo.s2_inference.reranking.rerank.rerank_search_results") def test_bulk_search_rerank_per_search_query(self, mock_rerank_search_results): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -507,7 +511,7 @@ def test_bulk_search_rerank_per_search_query(self, mock_rerank_search_results): assert call_arg['num_highlights'] == 1 def test_bulk_search_rerank_invalid(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, @@ -529,7 +533,7 @@ def test_each_doc_returned_once(self): """TODO: make sure each return only has one doc for each ID, - esp if matches are found in multiple fields """ - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe efgh ", "other field": "baaadd efgh ", "_id": "5678", "finally": "some field efgh "}, @@ -564,7 +568,7 @@ def test_bulk_vector_text_search_long_query_string(self): query_text = """The Guardian is a British daily newspaper. 
It was founded in 1821 as The Manchester Guardian, and changed its name in 1959.[5] Along with its sister papers The Observer and The Guardian Weekly, The Guardian is part of the Guardian Media Group, owned by the Scott Trust.[6] The trust was created in 1936 to "secure the financial and editorial independence of The Guardian in perpetuity and to safeguard the journalistic freedom and liberal values of The Guardian free from commercial or political interference".[7] The trust was converted into a limited company in 2008, with a constitution written so as to maintain for The Guardian the same protections as were built into the structure of the Scott Trust by its creators. Profits are reinvested in journalism rather than distributed to owners or shareholders.[7] It is considered a newspaper of record in the UK.[8][9] The editor-in-chief Katharine Viner succeeded Alan Rusbridger in 2015.[10][11] Since 2018, the paper's main newsprint sections have been published in tabloid format. As of July 2021, its print edition had a daily circulation of 105,134.[4] The newspaper has an online edition, TheGuardian.com, as well as two international websites, Guardian Australia (founded in 2013) and Guardian US (founded in 2011). The paper's readership is generally on the mainstream left of British political opinion,[12][13][14][15] and the term "Guardian reader" is used to imply a stereotype of liberal, left-wing or "politically correct" views.[3] Frequent typographical errors during the age of manual typesetting led Private Eye magazine to dub the paper the "Grauniad" in the 1960s, a nickname still used occasionally by the editors for self-mockery.[16] """ - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"_id": "12345", "Desc": "The Guardian is newspaper, read in the UK and other places around the world"}, {"_id": "abc12334", "Title": "Grandma Jo's family recipe. 
", @@ -578,7 +582,7 @@ def test_bulk_vector_text_search_long_query_string(self): assert len(search_res[0]['hits']) == 2 def test_bulk_vector_text_search_searchable_attributes(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "5678"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "1234"}, @@ -593,7 +597,7 @@ def test_bulk_vector_text_search_searchable_attributes(self): assert list(res["_highlights"].keys()) == ["other field"] def test_bulk_vector_text_search_searchable_attributes_multiple(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "Cool Field 1": "res res res", "_id": "5678"}, @@ -617,7 +621,7 @@ def test_bulk_vector_text_search_searchable_attributes_multiple(self): def test_search_format(self): """Is the result formatted correctly?""" q = "Exact match hehehe" - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "Cool Field 1": "res res res", "_id": "5678"}, @@ -643,7 +647,7 @@ def test_search_format(self): assert search_res['result'][0]["limit"] == 50 def test_result_count_validation(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "Cool Field 1": "res res res", "_id": "5678"}, @@ -678,7 +682,7 @@ def test_result_count_validation(self): assert len(search_results["result"][0]['hits']) >= 1 def test_highlights_tensor(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234"}, @@ -704,7 +708,7 @@ def test_highlights_tensor(self): def test_search_vector_int_field(self): """doesn't error out if there is a random int field""" - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_int": 144}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "my_int": 88}, @@ -718,7 +722,7 @@ def test_search_vector_int_field(self): assert len(s_res["hits"]) > 0 def test_filtering_list_case_tensor(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -760,7 +764,7 @@ def test_filtering_list_case_image(self): settings = {"index_defaults": {"treat_urls_and_pointers_as_images": True, "model": "ViT-B/32"}} tensor_search.create_vector_index(index_name=self.index_name_1, index_settings=settings, config=self.config) hippo_img = 'https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png' - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"img": hippo_img, "abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"img": hippo_img, "abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -793,7 +797,7 @@ def 
test_filtering_list_case_image(self): def test_filtering(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -814,7 +818,7 @@ def test_filtering(self): assert result[0]["hits"][j]["_id"] in expected_ids[i] def test_filter_spaced_fields(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -864,7 +868,7 @@ def run(): assert mock_config.search_device == "cpu" def test_search_other_types_subsearch(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, auto_refresh=True, docs=[{ "an_int": 1, @@ -887,7 +891,7 @@ def test_search_other_types_top_search(self): "a_bool": True, "some_str": "blah" }] - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, auto_refresh=True, docs=docs) for field, to_search in docs[0].items(): @@ -912,7 +916,7 @@ def test_attributes_to_retrieve_vector(self): "abc": "random text", "other field": "Close match hehehe"}, "9000": {"Cool Field 1": "somewhat match", "_id": "9000", "other field": "weewowow"} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) results = tensor_search.bulk_search( marqo_config=self.config, query=BulkSearchQuery( @@ -939,7 +943,7 @@ def test_attributes_to_retrieve_empty(self): "abc": "random text", "other field": "Close match hehehe"}, "9000": {"Cool Field 1": "somewhat match", "_id": "9000", "other field": "weewowow"} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) for method in ("LEXICAL", "TENSOR"): search_res = tensor_search.bulk_search( @@ -985,7 +989,7 @@ def test_attributes_to_retrieve_non_existent(self): "abc": "random a text", "other field": "Close match hehehe"}, "9000": {"Cool Field 1": "somewhat a match", "_id": "9000", "other field": "weewowow"} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) for to_retrieve in [[], ["non existing field name"], ["other field", "non existing field name"]]: for method in ("TENSOR", "LEXICAL"): @@ -1016,7 +1020,7 @@ def test_attributes_to_retrieve_and_searchable_attribs(self): "i_3": {"field_1": " a ", "_id": "i_3", "field_2": "a", "field_3": "a "} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) for to_retrieve, to_search, expected_ids, expected_fields in [ (["field_1"], ["field_3"], ["i_3"], ["field_1"]), @@ -1051,7 +1055,7 @@ def test_limit_results(self): vocab = requests.get(vocab_source).text.splitlines() - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[{"Title": "a " + (" ".join(random.choices(population=vocab, k=25)))} for _ in range(2000)], auto_refresh=False @@ -1127,11 +1131,10 @@ def test_limit_results_none(self): vocab_source = "https://www.mit.edu/~ecprice/wordlist.10000" vocab = 
requests.get(vocab_source).text.splitlines() - tensor_search.add_documents_orchestrator( - config=self.config, index_name=self.index_name_1, + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[{"Title": "a " + (" ".join(random.choices(population=vocab, k=25)))} - for _ in range(700)], auto_refresh=False, processes=4, batch_size=50 + for _ in range(700)], auto_refresh=False), processes=4, batch_size=50 ) tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) @@ -1165,7 +1168,7 @@ def test_pagination_single_field(self): vocab = requests.get(vocab_source).text.splitlines() num_docs = 2000 - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[{"Title": "a " + (" ".join(random.choices(population=vocab, k=25))), "_id": str(i) @@ -1291,7 +1294,7 @@ def test_pagination_multi_field_error(self): } ] - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=False ) @@ -1363,7 +1366,7 @@ def test_image_search_highlights(self): {"_id": "789", "image_field": url_2}, ] - tensor_search.add_documents( + add_docs_caller( config=self.config, auto_refresh=True, index_name=self.index_name_1, docs=docs ) res = tensor_search.bulk_search( @@ -1390,7 +1393,7 @@ def test_multi_search(self): {"field_a": "Construction and scaffolding equipment", "_id": 'irrelevant_doc'} ] - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1430,7 +1433,7 @@ def test_multi_search_images(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1479,7 +1482,7 @@ def test_multi_search_images_edge_cases(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1514,7 +1517,7 @@ def test_multi_search_images_ok_edge_cases(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1552,7 +1555,7 @@ def test_image_search(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1631,7 +1634,7 @@ def test_bulk_multi_search_check_vector(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1717,7 +1720,7 @@ def test_bulk_multi_search_check_vector_multiple_queries(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) diff --git a/tests/tensor_search/test_create_index.py 
b/tests/tensor_search/test_create_index.py index d9c061350..258864eeb 100644 --- a/tests/tensor_search/test_create_index.py +++ b/tests/tensor_search/test_create_index.py @@ -2,6 +2,7 @@ from typing import Any, Dict from unittest.mock import patch import requests +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.tensor_search.enums import IndexSettingsField, EnvVars from marqo.errors import MarqoApiError, MarqoError, IndexNotFoundError from marqo.tensor_search import tensor_search, configs, backend @@ -9,6 +10,7 @@ from tests.marqo_test import MarqoTestCase from marqo.tensor_search.enums import IndexSettingsField as NsField from unittest import mock +from marqo.tensor_search.models.settings_object import settings_schema from marqo import errors from marqo.errors import InvalidArgError @@ -171,7 +173,8 @@ def test_create_vector_index_default_knn_settings(self): config=self.config, index_name=self.index_name_1, index_settings={ NsField.index_defaults: custom_settings}) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"Title": "wowow"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"Title": "wowow"}], auto_refresh=True)) mappings = requests.get( url=self.endpoint + "/" + self.index_name_1 + "/_mapping", verify=False @@ -208,7 +211,8 @@ def test_create_vector_index_custom_knn_settings(self): config=self.config, index_name=self.index_name_1, index_settings={ NsField.index_defaults: custom_settings}) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"Title": "wowow"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"Title": "wowow"}], auto_refresh=True)) mappings = requests.get( url=self.endpoint + "/" + self.index_name_1 + "/_mapping", verify=False @@ -466,27 +470,31 @@ def test_field_limits(self): @mock.patch("os.environ", {EnvVars.MARQO_MAX_INDEX_FIELDS: str(lim)}) def run(): res_1 = tensor_search.add_documents( - index_name=self.index_name_1, docs=[ - {f"f{i}": "some content" for i in range(lim)}, - {"_id": "1234", **{f"f{i}": "new content" for i in range(lim)}}, - ], - auto_refresh=True, config=self.config + add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {f"f{i}": "some content" for i in range(lim)}, + {"_id": "1234", **{f"f{i}": "new content" for i in range(lim)}}, + ], + auto_refresh=True), + config=self.config ) assert not res_1['errors'] res_1_2 = tensor_search.add_documents( - index_name=self.index_name_1, docs=[ - {'f0': 'this is fine, but there is no resiliency.'}, - {f"f{i}": "some content" for i in range(lim // 2 + 1)}, - {'f0': 'this is fine. Still no resilieny.'} - ], - auto_refresh=True, config=self.config + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[ + {'f0': 'this is fine, but there is no resiliency.'}, + {f"f{i}": "some content" for i in range(lim // 2 + 1)}, + {'f0': 'this is fine. 
Still no resilieny.'}], + auto_refresh=True), + config=self.config ) assert not res_1_2['errors'] try: res_2 = tensor_search.add_documents( - index_name=self.index_name_1, docs=[ - {'fx': "blah"} - ], auto_refresh=True, config=self.config + add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=[{'fx': "blah"}], + auto_refresh=True), + config=self.config ) raise AssertionError except errors.IndexMaxFieldsError: @@ -504,14 +512,17 @@ def run(): {"f2": 49, "f3": 400.4, "f4": "alien message"} ] res_1 = tensor_search.add_documents( - index_name=self.index_name_1, docs=docs, auto_refresh=True, config=self.config + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=docs, auto_refresh=True), + config=self.config ) assert not res_1['errors'] try: res_2 = tensor_search.add_documents( - index_name=self.index_name_1, docs=[ - {'fx': "blah"} - ], auto_refresh=True, config=self.config + add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {'fx': "blah"} + ], auto_refresh=True), + config=self.config ) raise AssertionError except errors.IndexMaxFieldsError: @@ -535,7 +546,8 @@ def run(): {"f2": 49, "f3": 400.4, "f4": "alien message", "_id": "rkjn"} ] res_1 = tensor_search.add_documents( - index_name=self.index_name_1, docs=docs, auto_refresh=True, config=self.config + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=docs, auto_refresh=True), + config=self.config ) mapping_info = requests.get( self.authorized_url + f"/{self.index_name_1}/_mapping", diff --git a/tests/tensor_search/test_custom_vectors_search.py b/tests/tensor_search/test_custom_vectors_search.py index be882f88a..8e129610d 100644 --- a/tests/tensor_search/test_custom_vectors_search.py +++ b/tests/tensor_search/test_custom_vectors_search.py @@ -1,9 +1,4 @@ -import unittest.mock -import pprint - -import torch - -import marqo.tensor_search.backend +from tests.utils.transition import add_docs_caller from marqo.errors import IndexNotFoundError, InvalidArgError from marqo.tensor_search import tensor_search from marqo.tensor_search.enums import TensorField, IndexSettingsField, SearchMethod @@ -31,7 +26,7 @@ def setUp(self): IndexSettingsField.normalize_embeddings: True } }) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { "Title": "Horse rider", "text_field": "A rider is riding a horse jumping over the barrier.", diff --git a/tests/tensor_search/test_delete_documents.py b/tests/tensor_search/test_delete_documents.py index ae9fd45fb..bf3fe5995 100644 --- a/tests/tensor_search/test_delete_documents.py +++ b/tests/tensor_search/test_delete_documents.py @@ -11,7 +11,7 @@ from unittest.mock import patch from marqo import errors from marqo.tensor_search import enums - +from tests.utils.transition import add_docs_caller class TestDeleteDocuments(MarqoTestCase): """module that has tests at the tensor_search level""" @@ -33,7 +33,7 @@ def _delete_testing_indices(self): def test_delete_documents(self): # first batch: - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"f1": "cat dog sat mat", "Sydney": "Australia contains Sydney"}, @@ -44,7 +44,7 @@ def test_delete_documents(self): timeout=self.config.timeout, verify=False ).json()["count"] - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"hooped": "absolutely ridic", "Darling": "A harbour in Sydney", "_id": "455"}, @@ 
-65,7 +65,7 @@ def test_delete_documents(self): assert count_post_delete == count0_res def test_delete_docs_format(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"f1": "cat dog sat mat", "Sydney": "Australia contains Sydney", "_id": "1234"}, @@ -85,7 +85,7 @@ def test_delete_docs_format(self): def test_only_specified_documents_are_deleted(self): # Add multiple documents - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"sample_field": "sample value", "_id": "unique_id_1"}, @@ -114,7 +114,7 @@ def test_only_specified_documents_are_deleted(self): def test_delete_multiple_documents(self): # Create an index and add documents tensor_search.create_vector_index(index_name=self.index_name_1, config=self.config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"field1": "value1", "_id": "doc_id_1"}, @@ -141,7 +141,7 @@ def test_delete_multiple_documents(self): def test_document_is_actually_deleted(self): # Add a document - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[{"sample_field": "sample value", "_id": "unique_id"}], auto_refresh=True ) @@ -157,7 +157,7 @@ def test_document_is_actually_deleted(self): def test_multiple_documents_are_actually_deleted(self): # Add multiple documents - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"sample_field": "sample value", "_id": "unique_id_1"}, @@ -215,7 +215,7 @@ def test_delete_documents_with_invalid_ids(self): def test_delete_already_deleted_document(self): # Create an index and add a document tensor_search.create_vector_index(index_name=self.index_name_1, config=self.config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"field1": "value1", "_id": "doc_id_1"}, @@ -246,7 +246,7 @@ def test_delete_already_deleted_document(self): def test_delete_documents_mixed_valid_invalid_ids(self): # Create an index and add documents tensor_search.create_vector_index(index_name=self.index_name_1, config=self.config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"field1": "value1", "_id": "doc_id_1"}, @@ -422,7 +422,7 @@ def run(): tensor_search.create_vector_index( index_name=self.index_name_1, index_settings={"index_defaults": {"model": 'random'}}, config=self.config) # over the limit: - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"_id": x, 'Bad field': "blh "} for x in doc_ids ], diff --git a/tests/tensor_search/test_get_document.py b/tests/tensor_search/test_get_document.py index a966f81f5..113e49924 100644 --- a/tests/tensor_search/test_get_document.py +++ b/tests/tensor_search/test_get_document.py @@ -1,6 +1,6 @@ import functools import pprint - +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.tensor_search import enums from marqo.errors import IndexNotFoundError, InvalidDocumentIdError from marqo.tensor_search import tensor_search @@ -25,12 +25,14 @@ def _delete_testing_indices(self): def test_get_document(self): """Also ensures that the _id is returned""" tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "_id": "123", - "title 
1": "content 1", - "desc 2": "content 2. blah blah blah" - }], auto_refresh=True) + tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[ + { + "_id": "123", + "title 1": "content 1", + "desc 2": "content 2. blah blah blah" + }], auto_refresh=True) + ) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123") == { @@ -72,8 +74,10 @@ def test_get_document_vectors_format(self): tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) keys = ("title 1", "desc 2") vals = ("content 1", "content 2. blah blah blah") - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"_id": "123", **dict(zip(keys, vals))}], auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"_id": "123", **dict(zip(keys, vals))}], + auto_refresh=True) + ) res = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) diff --git a/tests/tensor_search/test_get_documents_by_ids.py b/tests/tensor_search/test_get_documents_by_ids.py index 885c0a2d5..46cb23323 100644 --- a/tests/tensor_search/test_get_documents_by_ids.py +++ b/tests/tensor_search/test_get_documents_by_ids.py @@ -9,6 +9,8 @@ from marqo.tensor_search import tensor_search from tests.marqo_test import MarqoTestCase from unittest import mock +from marqo.tensor_search.models.add_docs_objects import AddDocsParams + class TestGetDocuments(MarqoTestCase): @@ -25,10 +27,13 @@ def _delete_testing_indices(self): pass def test_get_documents_by_ids(self): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"_id": "1", "title 1": "content 1"}, {"_id": "2", "title 1": "content 1"}, - {"_id": "3", "title 1": "content 1"} - ], auto_refresh=True) + tensor_search.add_documents( + config=self.config, + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[ + {"_id": "1", "title 1": "content 1"}, {"_id": "2", "title 1": "content 1"}, + {"_id": "3", "title 1": "content 1"} + ], auto_refresh=True) + ) res = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_1, document_ids=['1', '2', '3'], show_vectors=True) @@ -39,8 +44,9 @@ def test_get_documents_vectors_format(self): keys = [("title 1", "desc 2", "_id"), ("title 1", "desc 2", "_id")] vals = [("content 1", "content 2. 
blah blah blah", "123"), ("some more content", "some cool desk", "5678")] - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - dict(zip(k, v)) for k, v in zip(keys, vals)], auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[dict(zip(k, v)) for k, v in zip(keys, vals)], + auto_refresh=True)) get_res = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_1, document_ids=["123", "5678"], show_vectors=True)['results'] @@ -77,8 +83,12 @@ def test_get_document_vectors_non_existent(self): def test_get_document_vectors_resilient(self): tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"_id": '456', "title": "alexandra"}, {'_id': '221', 'message': 'hello'}], auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"_id": '456', "title": "alexandra"}, + {'_id': '221', 'message': 'hello'}], + auto_refresh=True) + ) id_reqs = [ (['123', '456'], [False, True]), ([['456', '789'], [True, False]]), ([['456', '789', '221'], [True, False, True]]), ([['vkj', '456', '4891'], [False, True, False]]) @@ -119,8 +129,10 @@ def test_get_documents_env_limit(self): }}) docs = [{"Title": "a", "_id": uuid.uuid4().__str__()} for _ in range(2000)] tensor_search.add_documents_orchestrator( - config=self.config, index_name=self.index_name_1, - docs=docs, auto_refresh=False, batch_size=50, processes=4 + config=self.config, batch_size=50, processes=4, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=docs, auto_refresh=False + ) ) tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) for max_doc in [0, 1, 2, 5, 10, 100, 1000]: @@ -158,8 +170,11 @@ def test_limit_results_none(self): docs = [{"Title": "a", "_id": uuid.uuid4().__str__()} for _ in range(2000)] tensor_search.add_documents_orchestrator( - config=self.config, index_name=self.index_name_1, - docs=docs, auto_refresh=False, batch_size=50, processes=4 + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=docs, auto_refresh=False + ), + batch_size=50, processes=4 ) tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) diff --git a/tests/tensor_search/test_get_stats.py b/tests/tensor_search/test_get_stats.py index d4711a689..098a48e6f 100644 --- a/tests/tensor_search/test_get_stats.py +++ b/tests/tensor_search/test_get_stats.py @@ -1,10 +1,6 @@ -import json -import pprint -import time +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.errors import IndexNotFoundError, MarqoError from marqo.tensor_search import tensor_search, constants, index_meta_cache -import unittest -import copy from tests.marqo_test import MarqoTestCase @@ -32,9 +28,10 @@ def test_get_stats_non_empty(self): except IndexNotFoundError as s: pass tensor_search.add_documents( - docs=[{"1": "2"},{"134": "2"},{"14": "62"},], - config=self.config, - index_name=self.index_name_1, - auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + docs=[{"1": "2"},{"134": "2"},{"14": "62"}], + index_name=self.index_name_1, + auto_refresh=True + ) ) assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfDocuments"] == 3 diff --git 
a/tests/tensor_search/test_image_download_headers.py b/tests/tensor_search/test_image_download_headers.py index 81baddca7..170fef646 100644 --- a/tests/tensor_search/test_image_download_headers.py +++ b/tests/tensor_search/test_image_download_headers.py @@ -2,6 +2,7 @@ Module for testing image download headers. """ import unittest.mock +from marqo.tensor_search.models.add_docs_objects import AddDocsParams # we are renaming get to prevent inf. recursion while mocking get(): from requests import get as requests_get from marqo.tensor_search.models.api_models import BulkSearchQuery @@ -52,9 +53,10 @@ def test_img_download_search(self): config=self.config, index_name=self.index_name_1, index_settings=self.image_index_settings() ) image_download_headers = {"Authorization": "some secret key blah"} - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"_id": "1", "image": self.real_img_url} - ], auto_refresh=True, image_download_headers=image_download_headers) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"_id": "1", "image": self.real_img_url}], + auto_refresh=True, image_download_headers=image_download_headers)) def pass_through_requests_get(url, *args, **kwargs): return requests_get(url, *args, **kwargs) @@ -94,10 +96,11 @@ def run(): image_download_headers = {"Authorization": "some secret key blah"} # Add a document with an image URL - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { "_id": "1", "image": self.real_img_url} - ], auto_refresh=True, image_download_headers=image_download_headers) - + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { "_id": "1", "image": self.real_img_url} + ], auto_refresh=True, image_download_headers=image_download_headers + )) # Check if load_image_from_path was called with the correct headers assert len(mock_load_image_from_path.call_args_list) == 1 call_args, call_kwargs = mock_load_image_from_path.call_args_list[0] @@ -126,12 +129,13 @@ def pass_through_requests_get(url, *args, **kwargs): mock_load_image_from_path.side_effect = pass_through_load_image_from_path with unittest.mock.patch("marqo.s2_inference.clip_utils.load_image_from_path", mock_load_image_from_path): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "_id": "1", "image": test_image_url, - } - ], auto_refresh=True, image_download_headers=image_download_headers) + }], + auto_refresh=True, image_download_headers=image_download_headers)) # Set up the mock GET mock_get = unittest.mock.MagicMock() diff --git a/tests/tensor_search/test_index_meta_cache.py b/tests/tensor_search/test_index_meta_cache.py index 8f762e955..cd0595262 100644 --- a/tests/tensor_search/test_index_meta_cache.py +++ b/tests/tensor_search/test_index_meta_cache.py @@ -1,14 +1,9 @@ import copy import datetime -import pprint import threading import time -import unittest import requests - -import marqo.tensor_search.validation -from marqo.tensor_search.models.index_info import IndexInfo -from marqo.tensor_search.models import index_info +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.tensor_search import tensor_search from marqo.tensor_search import index_meta_cache from marqo.config import Config @@ -73,19 +68,24 @@ def 
test_search_works_on_cache_clear(self): def test_add_new_fields_preserves_index_cache(self): add_doc_res_1 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + config=self.config, + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True) ) add_doc_res_2 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"cool field": "yep yep", "haha": "heheh"}], - auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"cool field": "yep yep", "haha": "heheh"}], + auto_refresh=True + ) ) index_info_t0 = index_meta_cache.get_cache()[self.index_name_1] # reset cache: index_meta_cache.empty_cache() add_doc_res_3 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"newer field": "ndewr content", + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"newer field": "ndewr content", "goblin": "paradise"}], - auto_refresh=True + auto_refresh=True + ) ) for field in ["newer field", "goblin", "cool field", "abc", "haha"]: assert utils.generate_vector_name(field) \ @@ -94,10 +94,14 @@ def test_add_new_fields_preserves_index_cache(self): def test_delete_removes_index_from_cache(self): """note the implicit index creation""" add_doc_res_1 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + ) ) add_doc_res_2 = tensor_search.add_documents( - config=self.config, index_name=self.index_name_2, docs=[{"abc": "def"}], auto_refresh=True + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_2, docs=[{"abc": "def"}], auto_refresh=True + ) ) assert self.index_name_1 in index_meta_cache.get_cache() tensor_search.delete_index(index_name=self.index_name_1, config=self.config) @@ -118,8 +122,9 @@ def test_lexical_search_caching(self): d1 = {"some doc 1": "some 2 marqo", "field abc": "robodog is not a cat", "_id": "Jupyter_12"} d2 = {"exclude me": "marqo"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d1, d2]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[d0, d1, d2]) + ) # reset cache index_meta_cache.empty_cache() search_res =tensor_search._lexical_search( @@ -137,8 +142,10 @@ def test_get_documents_caching(self): d1 = {"some doc 1": "some 2 marqo", "field abc": "robodog is not a cat", "_id": "Jupyter_12"} d2 = {"exclude me": "marqo"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d1, d2 ]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, + docs=[d0, d1, d2 ]) + ) # reset cache index_meta_cache.empty_cache() search_res = tensor_search.get_document_by_id( @@ -171,8 +178,8 @@ def _simulate_externally_added_docs(self, index_name, docs, check_only_in_extern cache_t0 = copy.deepcopy(index_meta_cache.get_cache()) # mock external party indexing something: tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=docs, auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=docs, auto_refresh=True)) if 
check_only_in_external_cache is not None: assert ( @@ -199,8 +206,8 @@ def test_search_lexical_externally_created_field(self): tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"some field": "Plane 1"}], auto_refresh=True)) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") result = tensor_search.search( @@ -220,8 +227,8 @@ def test_search_vectors_externally_created_field(self): tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"some field": "Plane 1"}], auto_refresh=True)) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") result = tensor_search.search( @@ -238,8 +245,8 @@ def test_search_vectors_externally_created_field_attributes(self): tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"some field": "Plane 1"}], auto_refresh=True)) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") assert "brand new field" not in index_meta_cache.get_cache() @@ -255,8 +262,8 @@ def test_search_lexical_externally_created_field_attributes(self): tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"some field": "Plane 1"}], auto_refresh=True)) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") assert "brand new field" not in index_meta_cache.get_cache() @@ -275,8 +282,8 @@ def test_vector_search_non_existent_field(self): tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"some field": "Plane 1"}], auto_refresh=True)) assert "brand new field" not in index_meta_cache.get_cache() result = tensor_search.search( index_name=self.index_name_1, config=self.config, text="a line of text", @@ -289,8 +296,8 @@ def test_lexical_search_non_existent_field(self): tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"some field": "Plane 1"}], auto_refresh=True)) assert 
"brand new field" not in index_meta_cache.get_cache() # no error: result = tensor_search.search( @@ -303,8 +310,8 @@ def test_search_vectors_externally_created_field_attributes_cache_update(self): tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"some field": "Plane 1"}], auto_refresh=True)) time.sleep(2.5) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") @@ -537,8 +544,10 @@ def test_search_index_refresh_on_interval_multi_threaded(self): # we need to search it once, to to get something in the cache, otherwise # the threads will see an empty cache and try to fill it try: - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[{"hi": "hello"}], - auto_refresh=False) + tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[{"hi": "hello"}], + auto_refresh=False)) except IndexNotFoundError: pass @mock.patch('marqo._httprequests.ALLOWED_OPERATIONS', {mock_get}) @@ -607,12 +616,16 @@ def run(): index_settings={"index_defaults": {"model": "random"}}) clear_cache_thread = threading.Thread(target=clear_cache) clear_cache_thread.start() - tensor_search.add_documents(**{ - "config": self.config, "index_name": self.index_name_1, "auto_refresh": True, - "docs": [ - {"Title": "Blah"}, {"Title": "blah2"}, {"Title": "Blah3"}, {"Title": "Blah4"}, - ] - }) + tensor_search.add_documents( + config=self.config, + add_docs_params=AddDocsParams( + **{ + "index_name": self.index_name_1, "auto_refresh": True, + "docs": [ + {"Title": "Blah"}, {"Title": "blah2"}, + {"Title": "Blah3"}, {"Title": "Blah4"}] + }) + ) return True assert run() @@ -643,12 +656,16 @@ def run(): clear_cache_thread = threading.Thread(target=delete_index) clear_cache_thread.start() try: - tensor_search.add_documents(**{ - "config": self.config, "index_name": self.index_name_1, "auto_refresh": True, - "docs": [ - {"Title": "Blah"}, {"Title": "blah2"}, {"Title": "Blah3"}, {"Title": "Blah4"}, - ] - }) + tensor_search.add_documents( + **{"config": self.config}, + add_docs_params=AddDocsParams( + **{ + "index_name": self.index_name_1, "auto_refresh": True, + "docs": [ + {"Title": "Blah"}, {"Title": "blah2"}, + {"Title": "Blah3"}, {"Title": "Blah4"}] + }) + ) raise AssertionError except errors.IndexNotFoundError: pass diff --git a/tests/tensor_search/test_lexical_search.py b/tests/tensor_search/test_lexical_search.py index 0b446f14f..27c37e8d4 100644 --- a/tests/tensor_search/test_lexical_search.py +++ b/tests/tensor_search/test_lexical_search.py @@ -1,14 +1,14 @@ -import pprint import time from marqo.tensor_search import enums, backend from marqo.tensor_search import tensor_search -import unittest import copy from marqo.errors import InvalidArgError, IndexNotFoundError from tests.marqo_test import MarqoTestCase import random import requests import json +from marqo.tensor_search.models.add_docs_objects import AddDocsParams + class TestLexicalSearch(MarqoTestCase): @@ -41,16 +41,17 @@ def strip_marqo_fields(doc, strip_id=False): def test_lexical_search_empty_text(self): tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some doc 1": "some field 2", "some doc 2": "some other 
thing"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"some doc 1": "some field 2", "some doc 2": "some other thing"}], auto_refresh=True) + ) res = tensor_search._lexical_search(config=self.config, index_name=self.index_name_1, text="") assert len(res["hits"]) == 0 assert res["hits"] == [] def test_lexical_search_bad_text_type(self): tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=[{"some doc 1": "some field 2", "some doc 2": "some other thing"}], auto_refresh=True) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"some doc 1": "some field 2", "some doc 2": "some other thing"}], auto_refresh=True)) bad_args = [None, 1234, 1.0] for a in bad_args: try: @@ -72,9 +73,12 @@ def test_lexical_search_multiple(self): } d1 = {"title": "Marqo", "some doc 2": "some other thing", "_id": "abcdef"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d1, {"some doc 1": "some 2", "field abc": "robodog is not a cat", "_id": "unusual id"}, - d0]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, + docs=[d1, + {"some doc 1": "some 2", "field abc": "robodog is not a cat", "_id": "unusual id"}, + d0]) + ) res = tensor_search._lexical_search(config=self.config, index_name=self.index_name_1, text="marqo field", return_doc_ids=True) assert len(res["hits"]) == 2 @@ -95,11 +99,11 @@ def test_lexical_search_single_searchable_attribs(self): d4 = {"Lucy": "Travis", "field lambda": "there is a whole bunch of text here. " "Just a slight mention of a field", "_id": "123"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1 ]) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, + docs=[d0, d4, d1 ])) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d3, d2]) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, + docs=[d3, d2])) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="marqo field", return_doc_ids=True, searchable_attributes=["field lambda"], result_count=3) @@ -121,11 +125,13 @@ def test_lexical_search_multiple_searchable_attribs(self): d4 = {"Lucy": "Travis", "field lambda": "there is a whole bunch of text here. 
" "Just a slight mention of a field", "_id": "123"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, + docs=[d0, d4, d1])) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d3, d2]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[d3, d2]) + ) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field", return_doc_ids=True, searchable_attributes=["field lambda", "FIELD omega"]) @@ -148,11 +154,13 @@ def test_lexical_search_multiple_searchable_attribs_no_returned_ids(self): d4 = { # SHOULD APPEAR 3rd (LAST) "Lucy": "Travis", "field lambda": "sentence with the word field" } tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[d0, d4, d1]) + ) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d3, d2]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[d3, d2]) + ) time.sleep(1) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field awks", @@ -178,8 +186,8 @@ def test_lexical_search_result_count(self): "Just a slight mention of a field"} d5 = {"some completely irrelevant": "document hehehe"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1, d3, d2]) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, + docs=[d0, d4, d1, d3, d2])) r1 = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field", return_doc_ids=False, result_count=2 @@ -212,8 +220,8 @@ def test_search_lexical_param(self): "Just a slight mention of a field"} d5 = {"some completely irrelevant": "document hehehe"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1, d3, d2]) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, + docs=[d0, d4, d1, d3, d2])) res_lexical_search = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field", return_doc_ids=False, searchable_attributes=["field lambda", "FIELD omega"]) @@ -246,8 +254,8 @@ def test_lexical_search_overwriting_doc(self): "Cool field": "Marqo is the best!" 
} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0]) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, + docs=[d0])) assert [] == tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field", return_doc_ids=False)["hits"] @@ -258,8 +266,8 @@ def test_lexical_search_overwriting_doc(self): assert grey_query["hits"][0]["_id"] == a_consistent_id # update doc so it does indeed get returned tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d1]) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, + docs=[d1])) cool_query = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field", return_doc_ids=True) @@ -282,11 +290,11 @@ def test_lexical_search_filter(self): "Just a slight mention of a field", "day": 190, "_id": "123"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1 ]) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, + docs=[d0, d4, d1 ])) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d3, d2]) + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, + docs=[d3, d2])) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="marqo field", return_doc_ids=True, filter_string="title:Marqo OR (Lucy:Travis AND day:>50)" @@ -305,8 +313,10 @@ def test_lexical_search_empty_searchable_attribs(self): d2 = {"some doc 1": "some 2 jnkerkbj", "field abc": "extravagant robodog is not a cat", "_id": "Jupyter_12"} tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d1, d2]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, + docs=[d0, d1, d2]) + ) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="extravagant", return_doc_ids=True, searchable_attributes=[], result_count=3) @@ -369,8 +379,8 @@ def test_lexical_search_double_quotes(self): fields = ["Field 1", "Field 2", "Field 3"] tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, - docs=docs, auto_refresh=False + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=docs, auto_refresh=False) ) tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) @@ -430,12 +440,13 @@ def test_lexical_search_double_quotes(self): def test_lexical_search_list(self): tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[ - {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, - {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, - {"abc": "some text", "_id": "1235", "my_list": ["tag1", "tag2 some"]}, - {"abc": "some text", "_id": "1001", "my_cool_list": ["b_1", "b2"], "fun list": ['truk', 'car']}, - ], auto_refresh=True, non_tensor_fields=["my_list", "fun list", "my_cool_list"]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, + {"abc": "some text", "other field": "Close 
match hehehe", "_id": "1234", "an_int": 2}, + {"abc": "some text", "_id": "1235", "my_list": ["tag1", "tag2 some"]}, + {"abc": "some text", "_id": "1001", "my_cool_list": ["b_1", "b2"], "fun list": ['truk', 'car']}, + ], auto_refresh=True, non_tensor_fields=["my_list", "fun list", "my_cool_list"])) base_search_args = { 'index_name': self.index_name_1, "config": self.config, "search_method": enums.SearchMethod.LEXICAL @@ -464,12 +475,14 @@ def test_lexical_search_list(self): def test_lexical_search_list_searchable_attr(self): tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[ - {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, - {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, - {"abc": "some text", "_id": "1235", "my_list": ["tag1", "tag2 some"]}, - {"abc": "some text", "_id": "1001", "my_cool_list": ["b_1", "b2"], "fun list": ['truk', 'car']}, - ], auto_refresh=True, non_tensor_fields=["my_list", "fun list", "my_cool_list"]) + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, + {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, + {"abc": "some text", "_id": "1235", "my_list": ["tag1", "tag2 some"]}, + {"abc": "some text", "_id": "1001", "my_cool_list": ["b_1", "b2"], "fun list": ['truk', 'car']}, + ], auto_refresh=True, non_tensor_fields=["my_list", "fun list", "my_cool_list"]) + ) base_search_args = { 'index_name': self.index_name_1, "config": self.config, "search_method": enums.SearchMethod.LEXICAL, 'text': "tag1" @@ -485,12 +498,13 @@ def test_lexical_search_list_searchable_attr(self): def test_lexical_search_filter_with_dot(self): tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[ + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[ {"content": "a man on a horse", "filename" : "Important_File_1.pdf", "_id":"123"}, {"content": "the horse is eating grass", "filename": "Important_File_2.pdf", "_id": "456"}, {"content": "what is the document", "filename": "Important_File_3.pdf", "_id": "789"}, ], auto_refresh=True) + ) res = tensor_search._lexical_search(config=self.config, index_name=self.index_name_1, text="horse", return_doc_ids=True, searchable_attributes=["content"], diff --git a/tests/tensor_search/test_model_auth.py b/tests/tensor_search/test_model_auth.py new file mode 100644 index 000000000..cf7fe8fc9 --- /dev/null +++ b/tests/tensor_search/test_model_auth.py @@ -0,0 +1,999 @@ +"""todos: host a public HF-based CLIP (non-OpenCLIP) model so that we can use it for mocks and tests + +multiprocessing should be tested manually -problem with mocking (deadlock esque) +""" +from marqo.s2_inference.random_utils import Random +from marqo.s2_inference.s2_inference import _convert_vectorized_output +from marqo.tensor_search import tensor_search +from marqo.tensor_search.models.add_docs_objects import AddDocsParams +from marqo.tensor_search.models.private_models import S3Auth, ModelAuth, HfAuth +from marqo.errors import InvalidArgError, IndexNotFoundError, BadRequestError +from tests.marqo_test import MarqoTestCase +from marqo.s2_inference.model_downloading.from_s3 import get_s3_model_absolute_cache_path +from marqo.tensor_search.models.external_apis.s3 import S3Location +from unittest import mock +import unittest +import os +from marqo.errors import 
BadRequestError, ModelNotInCacheError +from marqo.tensor_search.models.api_models import BulkSearchQuery, BulkSearchQueryEntity + + +def fake_vectorise(*args, **_kwargs): + random_model = Random(model_name='blah', embedding_dim=512) + return _convert_vectorized_output(random_model.encode(_kwargs['content'])) + +def _delete_file(file_path): + try: + os.remove(file_path) + except FileNotFoundError: + pass + + +def _get_base_index_settings(): + return { + "index_defaults": { + "treat_urls_and_pointers_as_images": True, + "model": 'my_model', + "normalize_embeddings": True, + # notice model properties aren't here. Each test has to add it + } + } + +class TestModelAuthLoadedS3(MarqoTestCase): + """loads an s3 model loaded index, for tests that don't need to redownload + the model each time """ + + model_abs_path = None + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + index_name_1 = "test-model-auth-index-1" + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + custom_model_name = 'my_model' + device='cpu' + + @classmethod + def setUpClass(cls) -> None: + """Simulates downloading a model from a private and using it in an + add docs call + """ + super().setUpClass() + + cls.endpoint = cls.authorized_url + cls.generic_header = {"Content-type": "application/json"} + + try: + tensor_search.delete_index(config=cls.config, index_name=cls.index_name_1) + except IndexNotFoundError as s: + pass + + cls.model_abs_path = get_s3_model_absolute_cache_path( + S3Location( + Key=cls.s3_object_key, + Bucket=cls.s3_bucket + )) + _delete_file(cls.model_abs_path) + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": cls.s3_bucket, + "Key": cls.s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=cls.config, index_name=cls.index_name_1, index_settings=s3_settings) + + public_model_url = "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt" + + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. 
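# --- Illustrative sketch only, not part of the patch: the boto3 mocks used in
# this setUp stand in for a download flow that, judging by the assertions in
# these tests and by download_model(url=...) in clip_utils, looks roughly like
# the helper below. The helper name and its exact placement are assumptions.
def _sketch_download_private_s3_model(location, s3_auth):
    import boto3
    from marqo.s2_inference.processing.custom_clip_utils import download_model

    # A client is built from the caller-supplied credentials (ModelAuth.s3):
    client = boto3.client(
        's3',
        aws_access_key_id=s3_auth.aws_access_key_id,
        aws_secret_access_key=s3_auth.aws_secret_access_key,
        aws_session_token=None,
    )
    # A presigned URL lets the generic URL downloader fetch the private object:
    url = client.generate_presigned_url(
        'get_object', Params={'Bucket': location.Bucket, 'Key': location.Key}
    )
    return download_model(url=url)
# --- end of sketch ---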
+ mock_s3_client.generate_presigned_url.return_value = public_model_url + + # file should not yet exist: + assert not os.path.isfile(cls.model_abs_path) + + with unittest.mock.patch('boto3.client', return_value=mock_s3_client) as mock_boto3_client: + # Call the function that uses the generate_presigned_url method + res = tensor_search.add_documents(config=cls.config, add_docs_params=AddDocsParams( + index_name=cls.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], + model_auth=ModelAuth( + s3=S3Auth(aws_access_key_id=cls.fake_access_key_id, aws_secret_access_key=cls.fake_secret_key)) + )) + assert not res['errors'] + + # now the file exists + assert os.path.isfile(cls.model_abs_path) + + mock_s3_client.generate_presigned_url.assert_called_with( + 'get_object', + Params={'Bucket': 'your-bucket-name', 'Key': cls.s3_object_key} + ) + mock_boto3_client.assert_called_once_with( + 's3', + aws_access_key_id=cls.fake_access_key_id, + aws_secret_access_key=cls.fake_secret_key, + aws_session_token=None + ) + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass() + _delete_file(cls.model_abs_path) + tensor_search.eject_model(model_name=cls.custom_model_name, device=cls.device) + + def test_after_downloading_auth_doesnt_matter(self): + """on this instance, at least""" + res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}] + )) + assert not res['errors'] + + def test_after_downloading_doesnt_redownload(self): + """on this instance, at least""" + tensor_search.eject_model(model_name=self.custom_model_name, device=self.device) + mods = tensor_search.get_loaded_models()['models'] + assert not any([m['model_name'] == 'my_model' for m in mods]) + mock_req = mock.MagicMock() + with mock.patch('urllib.request.urlopen', mock_req): + res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}] + )) + assert not res['errors'] + mock_req.assert_not_called() + mods = tensor_search.get_loaded_models()['models'] + assert any([m['model_name'] == 'my_model' for m in mods]) + + def test_after_downloading_search_doesnt_redownload(self): + """on this instance, at least""" + tensor_search.eject_model(model_name=self.custom_model_name, device=self.device) + mods = tensor_search.get_loaded_models()['models'] + assert not any([m['model_name'] == 'my_model' for m in mods]) + mock_req = mock.MagicMock() + with mock.patch('urllib.request.urlopen', mock_req): + res = tensor_search.search(config=self.config, + index_name=self.index_name_1, text='hi' + ) + assert 'hits' in res + mock_req.assert_not_called() + + mods = tensor_search.get_loaded_models()['models'] + assert any([m['model_name'] == 'my_model' for m in mods]) + +class TestModelAuth(MarqoTestCase): + + device = 'cpu' + + def setUp(self) -> None: + self.endpoint = self.authorized_url + self.generic_header = {"Content-type": "application/json"} + self.index_name_1 = "test-model-auth-index-1" + try: + tensor_search.delete_index(config=self.config, index_name=self.index_name_1) + except IndexNotFoundError as s: + pass + + def tearDown(self) -> None: + try: + tensor_search.delete_index(config=self.config, index_name=self.index_name_1) + except IndexNotFoundError as s: + pass + + def test_model_auth_hf(self): + """ + Does not yet assert that a file is downloaded + """ + hf_object = "some_model.pt" + hf_repo_name = "MyRepo/test-private" + hf_token = 
"hf_some_secret_key" + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "hf": { + "repo_id": hf_repo_name, + "filename": hf_object, + }, + "auth_required": True + }, + "type": "open_clip", + } + hf_settings = _get_base_index_settings() + hf_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=hf_settings) + + mock_hf_hub_download = mock.MagicMock() + mock_hf_hub_download.return_value = 'cache/path/to/model.pt' + + mock_open_clip_creat_model = mock.MagicMock() + + with unittest.mock.patch('open_clip.create_model_and_transforms', mock_open_clip_creat_model): + with unittest.mock.patch('marqo.s2_inference.model_downloading.from_hf.hf_hub_download', mock_hf_hub_download): + try: + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], + model_auth=ModelAuth(hf=HfAuth(token=hf_token)))) + except BadRequestError as e: + # bad request due to no models actually being loaded + print(e) + pass + + mock_hf_hub_download.assert_called_once_with( + token=hf_token, + repo_id=hf_repo_name, + filename=hf_object + ) + + # is the open clip model being loaded with the expected args? + called_with_expected_args = any( + call.kwargs.get("pretrained") == "cache/path/to/model.pt" + and call.kwargs.get("model_name") == "ViT-B/32" + for call in mock_open_clip_creat_model.call_args_list + ) + assert len(mock_open_clip_creat_model.call_args_list) == 1 + assert called_with_expected_args, "Expected call not found" + + def test_model_auth_s3_search(self): + """The other test load from add_docs, we have to make sure it works for + search""" + + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + model_abs_path = get_s3_model_absolute_cache_path( + S3Location( + Key=s3_object_key, + Bucket=s3_bucket + )) + _delete_file(model_abs_path) + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + public_model_url = "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt" + + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. 
+ mock_s3_client.generate_presigned_url.return_value = public_model_url + + # file should not yet exist: + assert not os.path.isfile(model_abs_path) + + with unittest.mock.patch('boto3.client', return_value=mock_s3_client) as mock_boto3_client: + res = tensor_search.search( + config=self.config, text='hello', index_name=self.index_name_1, + model_auth=ModelAuth(s3=S3Auth(aws_access_key_id=fake_access_key_id, aws_secret_access_key=fake_secret_key)) + ) + + assert os.path.isfile(model_abs_path) + + mock_s3_client.generate_presigned_url.assert_called_with( + 'get_object', + Params={'Bucket': 'your-bucket-name', 'Key': s3_object_key} + ) + mock_boto3_client.assert_called_once_with( + 's3', + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key, + aws_session_token=None + ) + _delete_file(model_abs_path) + + def test_model_auth_hf_search(self): + """The other test focused on add_docs. This focuses on search + Does not yet assert that a file is downloaded + """ + hf_object = "some_model.pt" + hf_repo_name = "MyRepo/test-private" + hf_token = "hf_some_secret_key" + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "hf": { + "repo_id": hf_repo_name, + "filename": hf_object, + }, + "auth_required": True + }, + "type": "open_clip", + } + hf_settings = _get_base_index_settings() + hf_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=hf_settings) + + mock_hf_hub_download = mock.MagicMock() + mock_hf_hub_download.return_value = 'cache/path/to/model.pt' + + mock_open_clip_creat_model = mock.MagicMock() + + with unittest.mock.patch('open_clip.create_model_and_transforms', mock_open_clip_creat_model): + with unittest.mock.patch('marqo.s2_inference.model_downloading.from_hf.hf_hub_download', mock_hf_hub_download): + try: + res = tensor_search.search( + config=self.config, text='hello', index_name=self.index_name_1, + model_auth=ModelAuth(hf=HfAuth(token=hf_token))) + except BadRequestError: + # bad request due to no models actually being loaded + pass + + mock_hf_hub_download.assert_called_once_with( + token=hf_token, + repo_id=hf_repo_name, + filename=hf_object + ) + + # is the open clip model being loaded with the expected args? 
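# --- Illustrative sketch only, not part of the patch: going by the
# hf_hub_download and open_clip call expectations asserted in these HF tests,
# the private Hugging Face load path appears to resolve the checkpoint roughly
# as below. The wrapper name is an assumption; only the call shapes are taken
# from the assertions.
def _sketch_load_private_hf_open_clip(repo_id, filename, token, model_name="ViT-B/32"):
    from huggingface_hub import hf_hub_download
    import open_clip

    # The token authorises access to the private repo; a local cache path to
    # the downloaded checkpoint is returned:
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    # The checkpoint is then handed to OpenCLIP as the `pretrained` weights,
    # mirroring model_properties["name"] as the model_name:
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name=model_name, pretrained=checkpoint_path)
    return model, preprocess
# --- end of sketch ---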
+ called_with_expected_args = any( + call.kwargs.get("pretrained") == "cache/path/to/model.pt" + and call.kwargs.get("model_name") == "ViT-B/32" + for call in mock_open_clip_creat_model.call_args_list + ) + assert len(mock_open_clip_creat_model.call_args_list) == 1 + assert called_with_expected_args, "Expected call not found" + + def test_model_auth_mismatch_param_s3_ix(self): + """There isn't validation for the hf because users may download public models this way""" + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + public_model_url = "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt" + hf_token = 'hf_secret_token' + + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. + mock_s3_client.generate_presigned_url.return_value = public_model_url + + with unittest.mock.patch('boto3.client', return_value=mock_s3_client): + with self.assertRaises(BadRequestError) as cm: + tensor_search.search( + config=self.config, text='hello', index_name=self.index_name_1, + model_auth=ModelAuth(hf=HfAuth(token=hf_token))) + + self.assertIn("s3 authorisation information is required", str(cm.exception)) + + def test_model_loads_from_all_add_docs_derivatives(self): + """Does it work from add_docs, add_docs orchestrator and add_documents_mp? + """ + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + for add_docs_method, kwargs in [ + (tensor_search.add_documents_orchestrator, {'batch_size': 10}), + ]: + try: + tensor_search.eject_model(model_name='my_model' ,device=self.device) + except ModelNotInCacheError: + pass + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. 
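# (Note, grounded in the assertion on mock_download_pretrained_from_url below:
#  the presigned URL is expected to be handed to the generic URL downloader
#  with cache_dir=None and a cache_file_name taken from the basename of the S3
#  Key - 'secret_model.pt' in this case.)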
+ mock_s3_client.generate_presigned_url.return_value = "https://some_non_existent_model.pt" + + with unittest.mock.patch('boto3.client', return_value=mock_s3_client) as mock_boto3_client: + with self.assertRaises(BadRequestError) as cm: + with unittest.mock.patch( + 'marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url' + ) as mock_download_pretrained_from_url: + add_docs_method( + config=self.config, + add_docs_params=AddDocsParams( + index_name=self.index_name_1, + model_auth=ModelAuth(s3=S3Auth( + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key)), + auto_refresh=True, + docs=[{f'Title': "something {i} good"} for i in range(20)] + ), + **kwargs + ) + mock_download_pretrained_from_url.assert_called_once_with( + url='https://some_non_existent_model.pt', cache_dir=None, cache_file_name='secret_model.pt') + mock_s3_client.generate_presigned_url.assert_called_with( + 'get_object', + Params={'Bucket': 'your-bucket-name', 'Key': s3_object_key} + ) + mock_boto3_client.assert_called_once_with( + 's3', + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key, + aws_session_token=None + ) + mock_download_pretrained_from_url.reset_mock() + mock_s3_client.reset_mock() + mock_boto3_client.reset_mock() + + def test_model_loads_from_multi_search(self): + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + random_model = Random(model_name='blah', embedding_dim=512) + + try: + tensor_search.eject_model(model_name='my_model', device=self.device) + except ModelNotInCacheError: + pass + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. 
+ mock_s3_client.generate_presigned_url.return_value = "https://some_non_existent_model.pt" + + with unittest.mock.patch('marqo.s2_inference.s2_inference.vectorise', + side_effect=fake_vectorise) as mock_vectorise: + model_auth = ModelAuth( + s3=S3Auth( + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key) + ) + res = tensor_search.search( + index_name=self.index_name_1, + config=self.config, + model_auth=model_auth, + text={ + (f"https://raw.githubusercontent.com/marqo-ai/" + f"marqo-api-tests/mainline/assets/ai_hippo_realistic.png"): 0.3, + 'my text': -1.3 + }, + ) + assert 'hits' in res + mock_vectorise.assert_called() + assert len(mock_vectorise.call_args_list) > 0 + for _args, _kwargs in mock_vectorise.call_args_list: + assert _kwargs['model_properties']['model_location'] == { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + } + assert _kwargs['model_auth'] == model_auth + + def test_model_loads_from_multimodal_combination(self): + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + random_model = Random(model_name='blah', embedding_dim=512) + + + for add_docs_method, kwargs in [ + (tensor_search.add_documents_orchestrator, {'batch_size': 10}), + (tensor_search.add_documents, {}) + ]: + try: + tensor_search.eject_model(model_name='my_model', device=self.device) + except ModelNotInCacheError: + pass + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. 
+ mock_s3_client.generate_presigned_url.return_value = "https://some_non_existent_model.pt" + + with unittest.mock.patch('marqo.s2_inference.s2_inference.vectorise', side_effect=fake_vectorise) as mock_vectorise: + model_auth = ModelAuth( + s3=S3Auth( + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key) + ) + res = add_docs_method( + config=self.config, + add_docs_params=AddDocsParams( + index_name=self.index_name_1, + model_auth=model_auth, + auto_refresh=True, + docs=[{ + 'my_combination_field': { + 'my_image': f"https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png", + 'some_text': f"my text {i}"}} for i in range(20)], + mappings={ + "my_combination_field": { + "type": "multimodal_combination", + "weights": { + "my_image": 0.5, + "some_text": 0.5 + } + } + } + ), + **kwargs + ) + if isinstance(res, list): + assert all([not batch_res ['errors'] for batch_res in res]) + else: + assert not res['errors'] + mock_vectorise.assert_called() + for _args, _kwargs in mock_vectorise.call_args_list: + assert _kwargs['model_properties']['model_location'] == { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + } + assert _kwargs['model_auth'] == model_auth + + def test_no_creds_error(self): + """in s3, if there aren't creds""" + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + public_model_url = "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt" + hf_token = 'hf_secret_token' + + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. + mock_s3_client.generate_presigned_url.return_value = public_model_url + + with unittest.mock.patch('boto3.client', return_value=mock_s3_client): + with self.assertRaises(BadRequestError) as cm: + tensor_search.search( + config=self.config, text='hello', index_name=self.index_name_1, + ) + self.assertIn("s3 authorisation information is required", str(cm.exception)) + + + with unittest.mock.patch('boto3.client', return_value=mock_s3_client): + with self.assertRaises(BadRequestError) as cm2: + res = tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, + docs=[{'title': 'blah blah'}] + ) + ) + self.assertIn("s3 authorisation information is required", str(cm2.exception)) + + def test_bad_creds_error_s3(self): + """in s3 if creds aren't valid. 
Ensure a helpful error""" + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + + model_auth = ModelAuth( + s3=S3Auth( + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key) + ) + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + with self.assertRaises(BadRequestError) as cm: + tensor_search.search( + config=self.config, text='hello', index_name=self.index_name_1, + model_auth=model_auth + ) + self.assertIn("403 error when trying to retrieve model from s3", str(cm.exception)) + + with self.assertRaises(BadRequestError) as cm2: + res = tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, + docs=[{'title': 'blah blah'}], model_auth=model_auth + ) + ) + self.assertIn("403 error when trying to retrieve model from s3", str(cm2.exception)) + + def test_non_existent_hf_location(self): + hf_object = "some_model.pt" + hf_repo_name = "MyRepo/test-private" + hf_token = "hf_some_secret_key" + + model_auth = ModelAuth( + hf=HfAuth(token=hf_token) + ) + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "hf": { + "repo_id": hf_repo_name, + "filename": hf_object, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + with self.assertRaises(BadRequestError) as cm: + tensor_search.search( + config=self.config, text='hello', index_name=self.index_name_1, + model_auth=model_auth + ) + + self.assertIn("Could not find the specified Hugging Face model repository.", str(cm.exception)) + + with self.assertRaises(BadRequestError) as cm2: + res = tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, + docs=[{'title': 'blah blah'}], model_auth=model_auth + ) + ) + self.assertIn("Could not find the specified Hugging Face model repository.", str(cm.exception)) + + def test_bad_creds_error_hf(self): + """the model and repo do exist, but creds are bad. raises the same type of error + as the previous one. 
""" + hf_object = "dummy_model.pt" + hf_repo_name = "Marqo/test-private" + hf_token = "hf_some_secret_key" + + model_auth = ModelAuth( + hf=HfAuth(token=hf_token) + ) + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "hf": { + "repo_id": hf_repo_name, + "filename": hf_object, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + with self.assertRaises(BadRequestError) as cm: + tensor_search.search( + config=self.config, text='hello', index_name=self.index_name_1, + model_auth=model_auth + ) + self.assertIn("Could not find the specified Hugging Face model repository.", str(cm.exception)) + + with self.assertRaises(BadRequestError) as cm2: + res = tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, + docs=[{'title': 'blah blah'}], model_auth=model_auth + ) + ) + self.assertIn("Could not find the specified Hugging Face model repository.", str(cm.exception)) + + def test_bulk_search(self): + """Does it work with bulk search, including multi search + """ + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + + model_auth = ModelAuth( + s3=S3Auth( + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key) + ) + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + for bulk_search_query in [ + BulkSearchQuery(queries=[ + BulkSearchQueryEntity( + index=self.index_name_1, q="match", searchMethod="TENSOR", + modelAuth=model_auth + ), + BulkSearchQueryEntity( + index=self.index_name_1, q={"random text": 0.5, "other_text": -0.3}, + searchableAttributes=["abc"], searchMethod="TENSOR", + modelAuth=model_auth + ), + ]), + BulkSearchQuery(queries=[ + BulkSearchQueryEntity( + index=self.index_name_1, q={"random text": 0.5, "other_text": -0.3}, + searchableAttributes=["abc"], searchMethod="TENSOR", + modelAuth=model_auth + ), + ]) + ]: + try: + tensor_search.eject_model(model_name='my_model' ,device=self.device) + except ModelNotInCacheError: + pass + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. 
+ mock_s3_client.generate_presigned_url.return_value = "https://some_non_existent_model.pt" + + with unittest.mock.patch('boto3.client', return_value=mock_s3_client) as mock_boto3_client: + with self.assertRaises(InvalidArgError) as cm: + with unittest.mock.patch( + 'marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url' + ) as mock_download_pretrained_from_url: + tensor_search.bulk_search( + query=bulk_search_query, + marqo_config=self.config, + ) + mock_download_pretrained_from_url.assert_called_once_with( + url='https://some_non_existent_model.pt', cache_dir=None, cache_file_name='secret_model.pt') + mock_s3_client.generate_presigned_url.assert_called_with( + 'get_object', + Params={'Bucket': 'your-bucket-name', 'Key': s3_object_key} + ) + mock_boto3_client.assert_called_once_with( + 's3', + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key, + aws_session_token=None + ) + + mock_download_pretrained_from_url.reset_mock() + mock_s3_client.reset_mock() + mock_boto3_client.reset_mock() + + def test_bulk_search_vectorise(self): + """are the calls to vectorise expected? work with bulk search, including multi search + """ + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + + model_auth = ModelAuth( + s3=S3Auth( + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key) + ) + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) + + for bulk_search_query in [ + BulkSearchQuery(queries=[ + BulkSearchQueryEntity( + index=self.index_name_1, q="match", searchMethod="TENSOR", + modelAuth=model_auth + ), + BulkSearchQueryEntity( + index=self.index_name_1, q={"random text": 0.5, "other_text": -0.3}, + searchableAttributes=["abc"], searchMethod="TENSOR", + modelAuth=model_auth + ), + ]), + BulkSearchQuery(queries=[ + BulkSearchQueryEntity( + index=self.index_name_1, q={"random text": 0.5, "other_text": -0.3}, + searchableAttributes=["abc"], searchMethod="TENSOR", + modelAuth=model_auth + ), + ]) + ]: + try: + tensor_search.eject_model(model_name='my_model' ,device=self.device) + except ModelNotInCacheError: + pass + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. 
+ mock_s3_client.generate_presigned_url.return_value = "https://some_non_existent_model.pt" + + with unittest.mock.patch('marqo.s2_inference.s2_inference.vectorise', + side_effect=fake_vectorise) as mock_vectorise: + tensor_search.bulk_search( + query=bulk_search_query, + marqo_config=self.config, + ) + mock_vectorise.assert_called() + for _args, _kwargs in mock_vectorise.call_args_list: + assert _kwargs['model_properties']['model_location'] == { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, + }, + "auth_required": True + } + assert _kwargs['model_auth'] == model_auth + + mock_vectorise.reset_mock() + + def test_lexical_with_auth(self): + """should just skip""" + + def test_public_s3_no_auth(self): + """ + TODO + """ + + def test_public_hf_no_auth(self): + """ + TODO + """ + + def test_open_clip_reg_clip(self): + """both normal and open clip + TODO: normal CLIP + """ + + + + + + + + + diff --git a/tests/tensor_search/test_model_auth_cuda.py b/tests/tensor_search/test_model_auth_cuda.py new file mode 100644 index 000000000..943924925 --- /dev/null +++ b/tests/tensor_search/test_model_auth_cuda.py @@ -0,0 +1,130 @@ +"""todos: get a public HF-based ViT model so that we can use it for mocks and tests + +multiprocessing should be tested manually -problem with mocking (deadlock esque) +""" +from marqo.tensor_search import tensor_search +from marqo.tensor_search.models.add_docs_objects import AddDocsParams +from marqo.tensor_search.models.private_models import S3Auth, ModelAuth, HfAuth +from marqo.errors import InvalidArgError, IndexNotFoundError, BadRequestError +from tests.marqo_test import MarqoTestCase +from marqo.s2_inference.model_downloading.from_s3 import get_s3_model_absolute_cache_path +from marqo.tensor_search.models.external_apis.s3 import S3Location +from unittest import mock +from tests.tensor_search.test_model_auth import _delete_file, _get_base_index_settings +import unittest +import os +import torch +import pytest +from marqo.errors import BadRequestError + + +@pytest.mark.largemodel +@pytest.mark.skipif(torch.cuda.is_available() is False, reason="We skip the large model test if we don't have cuda support") +class TestModelAuthLoadedS3(MarqoTestCase): + """loads an s3 model loaded index, for tests """ + + model_abs_path = None + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + index_name_1 = "test-model-auth-index-1" + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' + custom_model_name = 'my_model' + device = 'cuda' + + @classmethod + def setUpClass(cls) -> None: + """Simulates downloading a model from a private and using it in an + add docs call + """ + super().setUpClass() + + cls.endpoint = cls.authorized_url + cls.generic_header = {"Content-type": "application/json"} + + try: + tensor_search.delete_index(config=cls.config, index_name=cls.index_name_1) + except IndexNotFoundError as s: + pass + + cls.model_abs_path = get_s3_model_absolute_cache_path( + S3Location( + Key=cls.s3_object_key, + Bucket=cls.s3_bucket + )) + _delete_file(cls.model_abs_path) + + model_properties = { + "name": "ViT-B/32", + "dimensions": 512, + "model_location": { + "s3": { + "Bucket": cls.s3_bucket, + "Key": cls.s3_object_key, + }, + "auth_required": True + }, + "type": "open_clip", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=cls.config, index_name=cls.index_name_1, index_settings=s3_settings) + + public_model_url = 
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt" + + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() + + # Mock the generate_presigned_url method of the mock Boto3 client with a real OpenCLIP model, so that + # the rest of the logic works. + mock_s3_client.generate_presigned_url.return_value = public_model_url + + # file should not yet exist: + assert not os.path.isfile(cls.model_abs_path) + + with unittest.mock.patch('boto3.client', return_value=mock_s3_client) as mock_boto3_client: + # Call the function that uses the generate_presigned_url method + res = tensor_search.add_documents(config=cls.config, add_docs_params=AddDocsParams( + index_name=cls.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], + device=cls.device, + model_auth=ModelAuth( + s3=S3Auth(aws_access_key_id=cls.fake_access_key_id, aws_secret_access_key=cls.fake_secret_key)) + )) + assert not res['errors'] + + assert os.path.isfile(cls.model_abs_path) + + mock_s3_client.generate_presigned_url.assert_called_with( + 'get_object', + Params={'Bucket': 'your-bucket-name', 'Key': cls.s3_object_key} + ) + mock_boto3_client.assert_called_once_with( + 's3', + aws_access_key_id=cls.fake_access_key_id, + aws_secret_access_key=cls.fake_secret_key, + aws_session_token=None + ) + + @classmethod + def tearDownClass(cls) -> None: + super().tearDownClass() + _delete_file(cls.model_abs_path) + + def test_after_downloading_auth_doesnt_matter(self): + """on this instance, at least""" + res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}], device=self.device + )) + assert not res['errors'] + + def test_after_downloading_doesnt_redownload(self): + """on this instance, at least""" + tensor_search.eject_model(model_name=self.custom_model_name, device=self.device) + mock_req = mock.MagicMock() + with mock.patch('urllib.request.urlopen', mock_req): + res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}], + device=self.device + )) + assert not res['errors'] + mock_req.assert_not_called() diff --git a/tests/tensor_search/test_multimodal_tensor_combination.py b/tests/tensor_search/test_multimodal_tensor_combination.py index 7e09e8f47..eca15652f 100644 --- a/tests/tensor_search/test_multimodal_tensor_combination.py +++ b/tests/tensor_search/test_multimodal_tensor_combination.py @@ -1,9 +1,5 @@ import unittest.mock -import pprint - -import torch - -import marqo.tensor_search.backend +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.errors import IndexNotFoundError, InvalidArgError from marqo.tensor_search import tensor_search from marqo.tensor_search.enums import TensorField, IndexSettingsField, SearchMethod @@ -56,20 +52,23 @@ def test_add_documents(self): }, "_id": "0" }) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - expected_doc, - - # this is just a dummy one - { - "Title": "Horse rider", - "text_field": "A rider is riding a horse jumping over the barrier.", - "image_field": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", - "_id": "1" - }, - ], mappings = {"combo_text_image" :{"type": "multimodal_combination", "weights" : { - "text" : 0.5, "image" : 0.8} - }},auto_refresh=True) - + tensor_search.add_documents(config=self.config, 
add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + expected_doc, + # this is just a dummy one + { + "Title": "Horse rider", + "text_field": "A rider is riding a horse jumping over the barrier.", + "image_field": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", + "_id": "1" + }, + ], + mappings = { + "combo_text_image": {"type": "multimodal_combination", "weights" : { + "text" : 0.5, "image" : 0.8} + }}, + auto_refresh=True) + ) added_doc = tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="0", show_vectors=True) for key, value in expected_doc.items(): @@ -101,9 +100,13 @@ def get_score(document): } }) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[document], - auto_refresh=True, mappings = {"combo_text_image" : {"type":"multimodal_combination", - "weights":{"image_field":0.5,"text_field":0.5}}}) + tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[document], + auto_refresh=True, mappings = {"combo_text_image" : {"type":"multimodal_combination", + "weights": {"image_field":0.5, "text_field":0.5}}} + ) + ) self.assertEqual(1, tensor_search.get_stats(config=self.config, index_name=self.index_name_1)[ "numberOfDocuments"]) res = tensor_search.search(config=self.config, index_name=self.index_name_1, @@ -140,66 +143,66 @@ def test_multimodal_tensor_combination_tensor_value(self): } }) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { "combo_text_image": { "text_field_1": "A rider is riding a horse jumping over the barrier.", "text_field_2": "What is the best to wear on the moon?", "image_field_1": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", "image_field_2": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + }, + "_id":"c1" }, - "_id":"c1" - }, - - { + { "combo_text_image": { "text_field_1": "A rider is riding a horse jumping over the barrier.", "image_field_1": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", "text_field_2": "What is the best to wear on the moon?", "image_field_2": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + }, + "_id": "c2" }, - "_id": "c2" - }, - - { + { "combo_text_image": { "image_field_1": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", "image_field_2": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", "text_field_1": "A rider is riding a horse jumping over the barrier.", "text_field_2": "What is the best to wear on the moon?", + }, + "_id": "c3" }, - "_id": "c3" - }, - - { + { "combo_text_image": { "image_field_1": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", "text_field_1": "A rider is riding a horse jumping over the barrier.", "text_field_2": "What is the best to wear on the moon?", "image_field_2": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + }, + "_id": "c4" }, - "_id": "c4" - }, - - { - "text_field_1": "A rider is riding a horse jumping 
over the barrier.", - "_id": "1" - }, - { - "text_field_2": "What is the best to wear on the moon?", - "_id": "2" - }, - { - "image_field_1": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", - "_id": "3" - }, - { - "image_field_2": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", - "_id": "4" - }, - - ], auto_refresh=True, mappings = {"combo_text_image" : {"type":"multimodal_combination", - "weights":{"text_field_1": 0.32,"text_field_2": 0, "image_field_1" : -0.48, "image_field_2": 1.34}}}) + { + "text_field_1": "A rider is riding a horse jumping over the barrier.", + "_id": "1" + }, + { + "text_field_2": "What is the best to wear on the moon?", + "_id": "2" + }, + { + "image_field_1": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", + "_id": "3" + }, + { + "image_field_2": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + "_id": "4" + }], + auto_refresh=True, + mappings = { + "combo_text_image" : { + "type":"multimodal_combination", + "weights":{"text_field_1": 0.32,"text_field_2": 0, "image_field_1" : -0.48, "image_field_2": 1.34}}} + )) combo_tensor_1 = np.array(tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="c1", @@ -251,10 +254,13 @@ def get_score(document): } }) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[document], - auto_refresh=True, mappings = {"combo_text_image" : {"type":"multimodal_combination", - "weights":{"image_field": 0,"text_field": 1}}}) - + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[document], + auto_refresh=True, mappings = { + "combo_text_image" : { + "type": "multimodal_combination", + "weights": {"image_field": 0,"text_field": 1}}} + )) self.assertEqual(1, tensor_search.get_stats(config=self.config, index_name=self.index_name_1)[ "numberOfDocuments"]) res = tensor_search.search(config=self.config, index_name=self.index_name_1, @@ -299,29 +305,32 @@ def pass_through_multimodal(*arg, **kwargs): @unittest.mock.patch("marqo.tensor_search.tensor_search.vectorise_multimodal_combination_field", mock_multimodal_combination) def run(): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "combo_text_image": { - "text_field": "A rider is riding a horse jumping over the barrier.", - "image_field": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "combo_text_image": { + "text_field": "A rider is riding a horse jumping over the barrier.", + "image_field": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", + }, + "_id": "123", }, - "_id": "123", - }, - - { - "combo_text_image": { - "text_field" : "test-text-two.", - "image_field":"https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + { + "combo_text_image": { + "text_field" : "test-text-two.", + "image_field":"https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + }, + "_id": "234", }, - "_id": "234", - }, - - { # a normal doc - "combo_text_image_test": 
"https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", - "_id": "534", - } - ], mappings = {"combo_text_image" : {"type":"multimodal_combination", - "weights":{"image_field": 0.5,"text_field": 0.5}}}, auto_refresh=True) + { # a normal doc + "combo_text_image_test": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + "_id": "534", + }], + mappings = { + "combo_text_image" : { + "type":"multimodal_combination", + "weights": {"image_field": 0.5,"text_field": 0.5}}}, + auto_refresh=True + )) # first multimodal-doc real_fied_0, field_content_0 = [call_args for call_args, call_kwargs @@ -361,14 +370,17 @@ def test_multimodal_field_content_dictionary_validation(self): }) # invalid field_content int - res_0 = tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + res_0 = tensor_search.add_documents(config=self.config, + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[ { "combo_text_image": { "A rider is riding a horse jumping over the barrier." : 0.5, "image_field" : 0.5, }, "_id": "123", - }, ], mappings=self.mappings, auto_refresh=True) + }], + mappings=self.mappings, auto_refresh=True) + ) assert res_0["errors"] assert not json.loads(requests.get(url = f"{self.endpoint}/{self.index_name_1}/_doc/123", verify=False).text)["found"] @@ -379,15 +391,18 @@ def test_multimodal_field_content_dictionary_validation(self): pass # invalid field content dict - res_1 = tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "combo_text_image": { - "text_field": "A rider is riding a horse jumping over the barrier.", - "image_field": {"image_url" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", - }, - }, - "_id": "123", - }, ], mappings=self.mappings, auto_refresh=True) + res_1 = tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "combo_text_image": { + "text_field": "A rider is riding a horse jumping over the barrier.", + "image_field": {"image_url" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + }, + }, + "_id": "123", + }], + mappings=self.mappings, auto_refresh=True)) assert res_1["errors"] assert not json.loads(requests.get(url = f"{self.endpoint}/{self.index_name_1}/_doc/123", verify=False).text)["found"] try: @@ -397,7 +412,8 @@ def test_multimodal_field_content_dictionary_validation(self): pass # invalid field name format - res_2 = tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + res_2 = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "combo_text_image": { "text_field" : "A rider is riding a horse jumping over the barrier.", @@ -405,7 +421,9 @@ def test_multimodal_field_content_dictionary_validation(self): }, "_id": "123", - }, ], mappings = self.mappings, auto_refresh=True) + }], + mappings = self.mappings, + auto_refresh=True)) assert res_2["errors"] assert not json.loads(requests.get(url = f"{self.endpoint}/{self.index_name_1}/_doc/123", verify=False).text)["found"] try: @@ -519,27 +537,28 @@ def pass_through_vectorise(*arg, **kwargs): @unittest.mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) def run(): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - 
"combo_text_image": { - "text_0": "A rider is riding a horse jumping over the barrier_0.", - "text_1":"A rider is riding a horse jumping over the barrier_1.", - "text_2":"A rider is riding a horse jumping over the barrier_2.", - "text_3":"A rider is riding a horse jumping over the barrier_3.", - "text_4":"A rider is riding a horse jumping over the barrier_4.", - "image_0" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg", - "image_1" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", - "image_2" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", - "image_3" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image3.jpg", - "image_4" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - }, - "_id": "111", - }, - - ], mappings = {"combo_text_image" :{"type":"multimodal_combination", "weights":{ - "text_0" : 0.1, "text_1" : 0.1, "text_2" : 0.1, "text_3" : 0.1, "text_4" : 0.1, - "image_0" : 0.1,"image_1" : 0.1,"image_2" : 0.1,"image_3" : 0.1,"image_4" : 0.1, - }}}, auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "combo_text_image": { + "text_0": "A rider is riding a horse jumping over the barrier_0.", + "text_1":"A rider is riding a horse jumping over the barrier_1.", + "text_2":"A rider is riding a horse jumping over the barrier_2.", + "text_3":"A rider is riding a horse jumping over the barrier_3.", + "text_4":"A rider is riding a horse jumping over the barrier_4.", + "image_0" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg", + "image_1" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", + "image_2" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + "image_3" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image3.jpg", + "image_4" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + }, + "_id": "111", + }], + mappings = {"combo_text_image" :{"type":"multimodal_combination", "weights":{ + "text_0" : 0.1, "text_1" : 0.1, "text_2" : 0.1, "text_3" : 0.1, "text_4" : 0.1, + "image_0" : 0.1,"image_1" : 0.1,"image_2" : 0.1,"image_3" : 0.1,"image_4" : 0.1, + }}}, + auto_refresh=True)) # Ensure the doc is added assert tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="111") # Ensure that vectorise is only called twice @@ -577,27 +596,30 @@ def pass_through_vectorise(*arg, **kwargs): @unittest.mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) def run(): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - { - "combo_text_image": { - "text_0": "A rider is riding a horse jumping over the barrier_0.", - "text_1":"A rider is riding a horse jumping over the barrier_1.", - "text_2":"A rider is riding a horse jumping over the barrier_2.", - "text_3":"A rider is riding a horse jumping over the barrier_3.", - "text_4":"A rider is riding a horse jumping over the barrier_4.", - "image_0" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg", - "image_1" : 
"https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", - "image_2" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", - "image_3" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image3.jpg", - "image_4" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - }, - "_id": "111", - }, - - ], mappings = {"combo_text_image" :{"type":"multimodal_combination", "weights":{ - "text_0" : 0.1, "text_1" : 0.1, "text_2" : 0.1, "text_3" : 0.1, "text_4" : 0.1, - "image_0" : 0.1,"image_1" : 0.1,"image_2" : 0.1,"image_3" : 0.1,"image_4" : 0.1, - }}}, auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "combo_text_image": { + "text_0": "A rider is riding a horse jumping over the barrier_0.", + "text_1":"A rider is riding a horse jumping over the barrier_1.", + "text_2":"A rider is riding a horse jumping over the barrier_2.", + "text_3":"A rider is riding a horse jumping over the barrier_3.", + "text_4":"A rider is riding a horse jumping over the barrier_4.", + "image_0" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image0.jpg", + "image_1" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image1.jpg", + "image_2" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", + "image_3" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image3.jpg", + "image_4" : "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + }, + "_id": "111", + }], + mappings = { + "combo_text_image": {"type":"multimodal_combination", "weights":{ + "text_0" : 0.1, "text_1" : 0.1, "text_2" : 0.1, "text_3" : 0.1, "text_4" : 0.1, + "image_0" : 0.1,"image_1" : 0.1,"image_2" : 0.1,"image_3" : 0.1,"image_4" : 0.1, + }}}, + auto_refresh=True) + ) # Ensure the doc is added assert tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="111") # Ensure that vectorise is only called twice @@ -634,7 +656,8 @@ def pass_through_load_image_from_path(*arg, **kwargs): @unittest.mock.patch("marqo.s2_inference.clip_utils.load_image_from_path", mock_load_image_from_path) def run(): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ { "combo_text_image": { "text_0": "A rider is riding a horse jumping over the barrier_0.", @@ -649,18 +672,17 @@ def run(): "image_4": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", }, "_id": "111", - }, - - ], mappings={"combo_text_image": {"type": "multimodal_combination", "weights": { - "text_0": 0.1, "text_1": 0.1, "text_2": 0.1, "text_3": 0.1, "text_4": 0.1, - "image_0": 0.1, "image_1": 0.1, "image_2": 0.1, "image_3": 0.1, "image_4": 0.1, - }}}, auto_refresh=True) + }], + mappings={ + "combo_text_image": {"type": "multimodal_combination", "weights": { + "text_0": 0.1, "text_1": 0.1, "text_2": 0.1, "text_3": 0.1, "text_4": 0.1, + "image_0": 0.1, "image_1": 0.1, "image_2": 0.1, "image_3": 0.1, "image_4": 0.1, + }}}, + auto_refresh=True)) assert 
tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="111") # Ensure that vectorise is only called twice assert len(mock_load_image_from_path.call_args_list) == 5 - return True - assert run() def test_lexical_search_on_multimodal_combination(self): @@ -672,39 +694,42 @@ def test_lexical_search_on_multimodal_combination(self): IndexSettingsField.normalize_embeddings: False } }) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"Title": "Extravehicular Mobility Unit (EMU)", + "Description": "The EMU is a spacesuit that provides environmental protection", + "_id": "article_591", + "Genre": "Science", + "my_combination_field": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "hello there", + "lexical_field": "search me please",}} + ], + mappings={ + "my_combination_field": { + "type": "multimodal_combination", + "weights": { + "my_image": 0.5, + "some_text": 0.5, + "lexical_field": 0.1, + "additional_field" : 0.2, + } + }}, + auto_refresh=True + )) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"Title": "Extravehicular Mobility Unit (EMU)", - "Description": "The EMU is a spacesuit that provides environmental protection", - "_id": "article_591", - "Genre": "Science", - "my_combination_field": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "hello there", - "lexical_field": "search me please",}},], - mappings={ - "my_combination_field": { - "type": "multimodal_combination", - "weights": { - "my_image": 0.5, - "some_text": 0.5, - "lexical_field": 0.1, - "additional_field" : 0.2, - } - }} - , auto_refresh=True) - - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"Title": "text", - "Description": "text_2", - "_id": "article_592", - "Genre": "text", - "my_combination_field": { - "my_image_1": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text_1": "hello there", - "lexical_field_1": "no no no", - "additional_field_1" : "test_search here"}} - ], + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + { + "Title": "text", + "Description": "text_2", + "_id": "article_592", + "Genre": "text", + "my_combination_field": { + "my_image_1": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text_1": "hello there", + "lexical_field_1": "no no no", + "additional_field_1" : "test_search here"}}], mappings={ "my_combination_field": { "type": "multimodal_combination", @@ -714,10 +739,9 @@ def test_lexical_search_on_multimodal_combination(self): "lexical_field_1": 0.1, "additional_field_1" : 0.2, } - }} - ,auto_refresh=True) - - + }}, + auto_refresh=True) + ) res = tensor_search._lexical_search(config=self.config, index_name=self.index_name_1, text="search me please") assert res["hits"][0]["_id"] == "article_591" @@ -734,17 +758,20 @@ def test_overwrite_multimodal_tensor_field(self): } }) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"Title": "Extravehicular Mobility Unit (EMU)", - "Description": "The EMU is a spacesuit that provides environmental protection", - "_id": "article_591", - "Genre": "Science", - 
"my_combination_field": "dummy" - },] - , auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"Title": "Extravehicular Mobility Unit (EMU)", + "Description": "The EMU is a spacesuit that provides environmental protection", + "_id": "article_591", + "Genre": "Science", + "my_combination_field": "dummy" + }], + auto_refresh=True + )) try: - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ {"Title": "text", "Description": "text_2", "_id": "article_592", @@ -764,9 +791,8 @@ def test_overwrite_multimodal_tensor_field(self): "lexical_field_1": 0.1, "additional_field_1" : 0.2, } - }} - ,auto_refresh=True) - + }}, + auto_refresh=True)) raise AssertionError except MarqoWebError: pass @@ -782,39 +808,37 @@ def test_search_with_filtering_and_infer_image_false(self): }) tensor_search.add_documents( - config=self.config, index_name=self.index_name_1, docs=[ - {"Title": "Extravehicular Mobility Unit (EMU)", - "_id": "0", - "my_combination_field": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "hello there", - "filter_field": "test_this_0", }}, - - {"Title": "Extravehicular Mobility Unit (EMU)", - "_id": "1", - "my_combination_field": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "hello there", - "filter_field": "test_this_1", }}, - - {"Title": "Extravehicular Mobility Unit (EMU)", - "_id": "2", - "my_combination_field": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "hello there", - "filter_field": "test_this_2", }}, - ], - mappings={ - "my_combination_field": { - "type": "multimodal_combination", - "weights": { - "my_image": 0.5, - "some_text": 0.5, - "filter_field": 0, + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"Title": "Extravehicular Mobility Unit (EMU)", + "_id": "0", + "my_combination_field": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "hello there", + "filter_field": "test_this_0", }}, + {"Title": "Extravehicular Mobility Unit (EMU)", + "_id": "1", + "my_combination_field": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "hello there", + "filter_field": "test_this_1", }}, + {"Title": "Extravehicular Mobility Unit (EMU)", + "_id": "2", + "my_combination_field": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "hello there", + "filter_field": "test_this_2", }}], + mappings={ + "my_combination_field": { + "type": "multimodal_combination", + "weights": { + "my_image": 0.5, + "some_text": 0.5, + "filter_field": 0, } - }} - , auto_refresh=True) - + }}, + auto_refresh=True + )) res_exist_0 = tensor_search.search(index_name=self.index_name_1, config=self.config, text = "", filter="my_combination_field.filter_field: test_this_0") @@ -839,41 +863,38 @@ def test_index_info_cache_update(self): IndexSettingsField.normalize_embeddings: False } }) - tensor_search.add_documents( - config=self.config, 
index_name=self.index_name_1, docs=[ - {"Title": "Extravehicular Mobility Unit (EMU)", - "_id": "0", - "my_combination_field": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "hello there", - "filter_field": "test_this_0", }}, - - {"Title": "what is this", - "_id": "1", - "my_combination_field": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "hello there", - "filter_field": "test_this_1", }}, - - {"Title": "have a test", - "_id": "2", - "my_combination_field": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "hello there", - "filter_field": "test_this_2", }}, - ], - mappings={ - "my_combination_field": { - "type": "multimodal_combination", - "weights": { - "my_image": 0.5, - "some_text": 0.5, - "filter_field": 0, - } - }} - , auto_refresh=True) - + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"Title": "Extravehicular Mobility Unit (EMU)", + "_id": "0", + "my_combination_field": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "hello there", + "filter_field": "test_this_0", }}, + {"Title": "what is this", + "_id": "1", + "my_combination_field": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "hello there", + "filter_field": "test_this_1", }}, + + {"Title": "have a test", + "_id": "2", + "my_combination_field": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "hello there", + "filter_field": "test_this_2"}}], + mappings={ + "my_combination_field": { + "type": "multimodal_combination", + "weights": { + "my_image": 0.5, + "some_text": 0.5, + "filter_field": 0, + } + }}, + auto_refresh=True)) pre_res_0 = tensor_search.search(index_name=self.index_name_1, config=self.config, text = "", filter="my_combination_field.filter_field: test_this_0") pre_res_1 = tensor_search.search(index_name=self.index_name_1, config=self.config, @@ -902,49 +923,51 @@ def test_duplication_in_child_fields(self): } }) - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"Title": "Extravehicular Mobility Unit (EMU)", - "Description": "The EMU is a spacesuit that provides environmental protection", - "_id": "article_591", - "Genre": "Science", - "my_combination_field_0": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "hello there", - "lexical_field": "search me please", }}, ], - mappings={ - "my_combination_field_0": { - "type": "multimodal_combination", - "weights": { - "my_image": 0.5, - "some_text": 0.5, - "lexical_field": 0.1, - } - }} - , auto_refresh=True) - - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - {"Title": "text", - "Description": "text_2", - "_id": "article_592", - "Genre": "text", - "my_combination_field_1": { - "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", - "some_text": "marqo is good", - "lexical_field": "no no no", - "additional_field": "i can hear you"}} - ], - mappings={ - "my_combination_field_1": { - "type": 
"multimodal_combination", - "weights": { - "my_image": 0.5, - "some_text": 0.5, - "lexical_field": 0.1, - "additional_field": 0.2, - } - }} - , auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"Title": "Extravehicular Mobility Unit (EMU)", + "Description": "The EMU is a spacesuit that provides environmental protection", + "_id": "article_591", + "Genre": "Science", + "my_combination_field_0": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "hello there", + "lexical_field": "search me please"}}], + mappings={ + "my_combination_field_0": { + "type": "multimodal_combination", + "weights": { + "my_image": 0.5, + "some_text": 0.5, + "lexical_field": 0.1, + } + }}, + auto_refresh=True)) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + {"Title": "text", + "Description": "text_2", + "_id": "article_592", + "Genre": "text", + "my_combination_field_1": { + "my_image": "https://raw.githubusercontent.com/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image4.jpg", + "some_text": "marqo is good", + "lexical_field": "no no no", + "additional_field": "i can hear you"}} + ], + mappings={ + "my_combination_field_1": { + "type": "multimodal_combination", + "weights": { + "my_image": 0.5, + "some_text": 0.5, + "lexical_field": 0.1, + "additional_field": 0.2, + } + }}, + auto_refresh=True) + ) true_text_fields = tensor_search.get_index_info(self.config, index_name=self.index_name_1).get_true_text_properties() # 3 from multimodal_field_0, 4 from multimodal_field_1, 3 common fields assert len(true_text_fields) == 10 @@ -973,11 +996,13 @@ def test_multimodal_combination_open_search_chunks(self): res = tensor_search.add_documents( self.config, - docs = [test_doc], - auto_refresh=True, index_name=self.index_name_1, - mappings={"my_combination_field": {"type":"multimodal_combination", "weights":{ - "text":0.5, "image":0.5 - }}} + add_docs_params=AddDocsParams( + docs = [test_doc], + auto_refresh=True, index_name=self.index_name_1, + mappings={"my_combination_field": {"type":"multimodal_combination", "weights":{ + "text":0.5, "image":0.5 + }}} + ) ) doc_w_facets = tensor_search.get_document_by_id( @@ -1058,11 +1083,14 @@ def test_multimodal_child_fields_order(self): } with patch("numpy.mean", wraps=np.mean) as mock_mean: - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - doc, doc_1, doc_2, doc_3 - ], mappings={"combo_text_image": {"type": "multimodal_combination", - "weights": {"image_field_1": 0.2, "image_field_2": -1, "text_field_1": 0.38, "text_field_2": 0}}}, auto_refresh=True) - + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[ + doc, doc_1, doc_2, doc_3 + ], mappings={"combo_text_image": {"type": "multimodal_combination", + "weights": {"image_field_1": 0.2, "image_field_2": -1, + "text_field_1": 0.38, "text_field_2": 0}}}, + auto_refresh=True) + ) args_list = [args[0] for args in mock_mean.call_args_list] combined_tensor = np.squeeze(np.mean(args_list[0][0], axis = 0)) @@ -1126,11 +1154,16 @@ def test_multimodal_child_fields_order_from_os(self): } with patch("numpy.mean", wraps=np.mean) as mock_mean: - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ - doc, doc_1, doc_2, doc_3 - ], mappings={"combo_text_image": 
{"type": "multimodal_combination", - "weights": {"image_field_1": 0.2, "image_field_2": -1, - "text_field_1": 0.38, "text_field_2": 0}}}, auto_refresh=True) + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=[doc, doc_1, doc_2, doc_3], + mappings={ + "combo_text_image": { + "type": "multimodal_combination", + "weights": { + "image_field_1": 0.2, "image_field_2": -1, + "text_field_1": 0.38, "text_field_2": 0}}}, + auto_refresh=True) + ) docs = tensor_search.get_documents_by_ids( config=self.config, document_ids=["d0", "d1", "d2", "d3"], index_name=self.index_name_1, show_vectors=True) diff --git a/tests/tensor_search/test_parallel.py b/tests/tensor_search/test_parallel.py index b9340bc3a..d45587fe4 100644 --- a/tests/tensor_search/test_parallel.py +++ b/tests/tensor_search/test_parallel.py @@ -1,11 +1,9 @@ from marqo.errors import IndexNotFoundError -import unittest -import copy from marqo.tensor_search import parallel import torch from tests.marqo_test import MarqoTestCase from marqo.tensor_search import tensor_search - +from marqo.tensor_search.models.add_docs_objects import AddDocsParams class TestAddDocumentsPara(MarqoTestCase): """ @@ -44,6 +42,7 @@ def test_add_documents_parallel(self) -> None: data = [{'text':f'something {str(i)}', '_id': str(i)} for i in range(100)] - res = tensor_search.add_documents_orchestrator(config=self.config, index_name=self.index_name_1, docs=data, - batch_size=10, processes=1, auto_refresh=True) + res = tensor_search.add_documents_orchestrator(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=data, auto_refresh=True), + batch_size=10, processes=1) res = tensor_search.search(config=self.config, text='something', index_name=self.index_name_1) \ No newline at end of file diff --git a/tests/tensor_search/test_score_modifiers_search.py b/tests/tensor_search/test_score_modifiers_search.py index 08a04a6a1..63bd7a6f8 100644 --- a/tests/tensor_search/test_score_modifiers_search.py +++ b/tests/tensor_search/test_score_modifiers_search.py @@ -1,6 +1,6 @@ import copy import unittest.mock -from unittest.mock import patch +from tests.utils.transition import add_docs_caller from marqo.errors import IndexNotFoundError, InvalidArgError from marqo.tensor_search import tensor_search from marqo.tensor_search.enums import TensorField, IndexSettingsField, SearchMethod @@ -150,7 +150,7 @@ def test_search_result_not_affected_if_fields_not_exist(self): }, ] - tensor_search.add_documents(config=self.config, index_name=self.index_name, docs=documents, + add_docs_caller(config=self.config, index_name=self.index_name, docs=documents, non_tensor_fields=["multiply_1", "multiply_2", "add_1", "add_2", "filter"], auto_refresh=True) @@ -199,7 +199,7 @@ def get_expected_score(self, doc, ori_score, score_modifiers): def test_search_score_modified_as_expected(self): documents = self.test_score_documents - tensor_search.add_documents(config=self.config, index_name=self.index_name, docs=documents, + add_docs_caller(config=self.config, index_name=self.index_name, docs=documents, non_tensor_fields=["multiply_1", "multiply_2", "add_1", "add_2", "filter"], auto_refresh=True) normal_res = tensor_search.search(config=self.config, index_name=self.index_name, @@ -223,7 +223,7 @@ def test_search_score_modified_as_expected(self): def test_search_score_modified_as_expected_with_filter(self): documents = self.test_score_documents - tensor_search.add_documents(config=self.config, 
index_name=self.index_name, docs=documents, + add_docs_caller(config=self.config, index_name=self.index_name, docs=documents, non_tensor_fields=["multiply_1", "multiply_2", "add_1", "add_2", "filter"], auto_refresh=True) normal_res = tensor_search.search(config=self.config, index_name=self.index_name, @@ -248,7 +248,7 @@ def test_search_score_modified_as_expected_with_filter(self): def test_search_score_modified_as_expected_with_searchable_attributes(self): documents = self.test_score_documents - tensor_search.add_documents(config=self.config, index_name=self.index_name, docs=documents, + add_docs_caller(config=self.config, index_name=self.index_name, docs=documents, non_tensor_fields=["multiply_1", "multiply_2", "add_1", "add_2", "filter"], auto_refresh=True) normal_res = tensor_search.search(config=self.config, index_name=self.index_name, @@ -275,7 +275,7 @@ def test_search_score_modified_as_expected_with_attributes_to_retrieve(self): invalid_fields = copy.deepcopy(documents[0]) del invalid_fields["_id"] - tensor_search.add_documents(config=self.config, index_name=self.index_name, docs=documents, + add_docs_caller(config=self.config, index_name=self.index_name, docs=documents, non_tensor_fields=["multiply_1", "multiply_2", "add_1", "add_2", "filter"], auto_refresh=True) normal_res = tensor_search.search(config=self.config, index_name=self.index_name, @@ -353,7 +353,7 @@ def test_search_score_modified_as_expected_with_skipped_fields(self): }, ] - tensor_search.add_documents(config=self.config, index_name=self.index_name, docs=documents, + add_docs_caller(config=self.config, index_name=self.index_name, docs=documents, non_tensor_fields=["multiply_1", "multiply_2", "add_1", "add_2", "filter"], auto_refresh=True) @@ -520,7 +520,7 @@ def test_expected_error_raised(self): }, ] - tensor_search.add_documents(config=self.config, index_name=self.index_name, docs=documents, + add_docs_caller(config=self.config, index_name=self.index_name, docs=documents, non_tensor_fields=["multiply_1", "multiply_2", "add_1", "add_2", "filter"], auto_refresh=True) for invalid_score_modifiers in invalid_score_modifiers_list: @@ -540,7 +540,7 @@ def test_normal_query_body_is_called(self): "filter": "original" }, ] - tensor_search.add_documents(config=self.config, index_name=self.index_name, docs=documents, + add_docs_caller(config=self.config, index_name=self.index_name, docs=documents, non_tensor_fields=["multiply_1", "multiply_2", "add_1", "add_2", "filter"], auto_refresh=True) def pass_create_normal_tensor_search_query(*arg, **kwargs): @@ -572,7 +572,7 @@ def test_bulk_search_not_support_score_modifiers(self): {"field_name": "add_2", "weight": 1, }] } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=index_name, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, diff --git a/tests/tensor_search/test_search.py b/tests/tensor_search/test_search.py index cfbc7c027..76a35e127 100644 --- a/tests/tensor_search/test_search.py +++ b/tests/tensor_search/test_search.py @@ -1,7 +1,8 @@ import math import os import sys - +from tests.utils.transition import add_docs_caller +from marqo.tensor_search.models.add_docs_objects import AddDocsParams from unittest import mock from marqo.s2_inference.s2_inference import vectorise import numpy as np @@ -44,7 +45,7 @@ def test_each_doc_returned_once(self): """TODO: make sure each return only has one doc for each ID, - esp if matches are found in 
multiple fields """ - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe efgh ", "other field": "baaadd efgh ", "_id": "5678", "finally": "some field efgh "}, @@ -60,7 +61,7 @@ def test_each_doc_returned_once(self): @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '2'}}) def test_search_with_excessive_searchable_attributes(self): with self.assertRaises(InvalidArgError): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "5678"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "1234"}, @@ -73,7 +74,7 @@ def test_search_with_excessive_searchable_attributes(self): @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '2'}}) def test_search_with_allowable_num_searchable_attributes(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "5678"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "1234"}, @@ -85,7 +86,7 @@ def test_search_with_allowable_num_searchable_attributes(self): @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': None}}) def test_search_with_searchable_attributes_max_attributes_is_none(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "5678"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "1234"}, @@ -98,7 +99,7 @@ def test_search_with_searchable_attributes_max_attributes_is_none(self): @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': f"{sys.maxsize}"}}) def test_search_with_no_searchable_attributes_but_max_searchable_attributes_env_set(self): with self.assertRaises(InvalidArgError): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "5678"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "1234"}, @@ -126,7 +127,7 @@ def test_vector_search_long_query_string(self): query_text = """The Guardian is a British daily newspaper. It was founded in 1821 as The Manchester Guardian, and changed its name in 1959.[5] Along with its sister papers The Observer and The Guardian Weekly, The Guardian is part of the Guardian Media Group, owned by the Scott Trust.[6] The trust was created in 1936 to "secure the financial and editorial independence of The Guardian in perpetuity and to safeguard the journalistic freedom and liberal values of The Guardian free from commercial or political interference".[7] The trust was converted into a limited company in 2008, with a constitution written so as to maintain for The Guardian the same protections as were built into the structure of the Scott Trust by its creators. Profits are reinvested in journalism rather than distributed to owners or shareholders.[7] It is considered a newspaper of record in the UK.[8][9] The editor-in-chief Katharine Viner succeeded Alan Rusbridger in 2015.[10][11] Since 2018, the paper's main newsprint sections have been published in tabloid format. 
As of July 2021, its print edition had a daily circulation of 105,134.[4] The newspaper has an online edition, TheGuardian.com, as well as two international websites, Guardian Australia (founded in 2013) and Guardian US (founded in 2011). The paper's readership is generally on the mainstream left of British political opinion,[12][13][14][15] and the term "Guardian reader" is used to imply a stereotype of liberal, left-wing or "politically correct" views.[3] Frequent typographical errors during the age of manual typesetting led Private Eye magazine to dub the paper the "Grauniad" in the 1960s, a nickname still used occasionally by the editors for self-mockery.[16] """ - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"_id": "12345", "Desc": "The Guardian is newspaper, read in the UK and other places around the world"}, {"_id": "abc12334", "Title": "Grandma Jo's family recipe. ", @@ -138,7 +139,7 @@ def test_vector_search_long_query_string(self): ) def test_vector_search_all_highlights(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe efgh ", "other field": "baaadd efgh ", "_id": "5678", "finally": "some field efgh "}, @@ -153,7 +154,7 @@ def test_vector_search_all_highlights(self): assert len(res["highlights"]) == 3 def test_vector_search_n_highlights(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe efgh ", "other field": "baaadd efgh ", "_id": "5678", "finally": "some field efgh "}, @@ -168,7 +169,7 @@ def test_vector_search_n_highlights(self): assert len(res["highlights"]) == 2 def test_vector_search_searchable_attributes(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "5678"}, {"abc": "random text", "other field": "Close match hehehe", "_id": "1234"}, @@ -182,7 +183,7 @@ def test_vector_search_searchable_attributes(self): assert list(res["_highlights"].keys()) == ["other field"] def test_vector_search_searchable_attributes_multiple(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "Cool Field 1": "res res res", "_id": "5678"}, @@ -200,7 +201,7 @@ def test_vector_search_searchable_attributes_multiple(self): def test_tricky_search(self): """We ran into bugs with this doc""" - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs = [ { 'text': 'In addition to NiS collection fire assay for a five element PGM suite, the samples will undergo research quality analyses for a wide range of elements, including the large ion. 
, the rare earth elements, high field strength elements, sulphur and selenium.hey include 55 elements of the periodic system: O, Si, Al, Ti, B, C, all the alkali and alkaline-earth metals, the halogens, and many of the rare elements.', @@ -219,7 +220,7 @@ def test_tricky_search(self): def test_search_format(self): """Is the result formatted correctly?""" q = "Exact match hehehe" - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "Cool Field 1": "res res res", "_id": "5678"}, @@ -258,7 +259,7 @@ def test_search_format_empty(self): assert search_res["limit"] > 0 def test_result_count_validation(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "Cool Field 1": "res res res", "_id": "5678"}, @@ -300,7 +301,7 @@ def test_result_count_validation(self): assert len(search_res['hits']) >= 1 def test_highlights_tensor(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234"}, @@ -321,7 +322,7 @@ def test_highlights_tensor(self): assert "_highlights" not in hit def test_highlights_lexical(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234"}, @@ -343,7 +344,7 @@ def test_highlights_lexical(self): def test_search_lexical_int_field(self): """doesn't error out if there is a random int field""" - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_int": 144}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "my_int": 88}, @@ -356,7 +357,7 @@ def test_search_lexical_int_field(self): def test_search_vector_int_field(self): """doesn't error out if there is a random int field""" - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_int": 144}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "my_int": 88}, @@ -368,7 +369,7 @@ def test_search_vector_int_field(self): assert len(s_res["hits"]) > 0 def test_filtering_list_case_tensor(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -405,7 +406,7 @@ def test_filtering_list_case_tensor(self): assert len(res_should_only_match_keyword_good["hits"]) == 1 def test_filtering_list_case_lexical(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -444,7 +445,7 @@ def test_filtering_list_case_image(self): settings = {"index_defaults": {"treat_urls_and_pointers_as_images": True, "model": "ViT-B/32"}} 
tensor_search.create_vector_index(index_name=self.index_name_1, index_settings=settings, config=self.config) hippo_img = 'https://raw.githubusercontent.com/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png' - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"img": hippo_img, "abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"img": hippo_img, "abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -483,7 +484,7 @@ def test_filtering_list_case_image(self): assert len(res_lex_none["hits"]) == 0 def test_filtering(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -561,7 +562,7 @@ def test_filtering(self): )["hits"]) def test_filter_spaced_fields(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -590,7 +591,7 @@ def test_filter_spaced_fields(self): assert res_float['hits'][0]['_id'] == "344" def test_filtering_bad_syntax(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "some text", "other field": "baaadd", "_id": "5678", "my_string": "b"}, {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, @@ -627,7 +628,7 @@ def run(): assert kwargs["device"] == "cuda:123" def test_search_other_types_subsearch(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, auto_refresh=True, docs=[{ "an_int": 1, @@ -651,7 +652,7 @@ def test_search_other_types_top_search(self): "a_bool": True, "some_str": "blah" }] - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, auto_refresh=True, docs=docs) for field, to_search in docs[0].items(): @@ -666,7 +667,7 @@ def test_search_other_types_top_search(self): ) def test_lexical_filtering(self): - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ { "doc title": "The captain bravely lead her followers into battle." 
@@ -719,7 +720,7 @@ def test_attributes_to_retrieve_vector(self): "abc": "random text", "other field": "Close match hehehe"}, "9000": {"Cool Field 1": "somewhat match", "_id": "9000", "other field": "weewowow"} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) search_res = tensor_search.search( config=self.config, index_name=self.index_name_1, text="Exact match hehehe", @@ -741,7 +742,7 @@ def test_attributes_to_retrieve_lexical(self): "abc": "random text", "other field": "Close match hehehe"}, "9000": {"Cool Field 1": "somewhat match", "_id": "9000", "other field": "weewowow"} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) search_res = tensor_search.search( config=self.config, index_name=self.index_name_1, text="Exact match hehehe", @@ -762,7 +763,7 @@ def test_attributes_to_retrieve_empty(self): "abc": "random text", "other field": "Close match hehehe"}, "9000": {"Cool Field 1": "somewhat match", "_id": "9000", "other field": "weewowow"} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) for method in ("LEXICAL", "TENSOR"): search_res = tensor_search.search( @@ -793,7 +794,7 @@ def test_attributes_to_retrieve_non_existent(self): "abc": "random a text", "other field": "Close match hehehe"}, "9000": {"Cool Field 1": "somewhat a match", "_id": "9000", "other field": "weewowow"} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) for to_retrieve in [[], ["non existing field name"], ["other field", "non existing field name"]]: for method in ("TENSOR", "LEXICAL"): @@ -817,7 +818,7 @@ def test_attributes_to_retrieve_and_searchable_attribs(self): "i_3": {"field_1": " a ", "_id": "i_3", "field_2": "a", "field_3": "a "} } - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=list(docs.values()), auto_refresh=True) for to_retrieve, to_search, expected_ids, expected_fields in [ (["field_1"], ["field_3"], ["i_3"], ["field_1"]), @@ -839,7 +840,7 @@ def test_attributes_to_retrieve_and_searchable_attribs(self): ) == relevant_fields def test_attributes_to_retrieve_non_list(self): - tensor_search.add_documents(config=self.config, index_name=self.index_name_1, + add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[{"cool field 111": "this is some content"}], auto_refresh=True) for method in ("TENSOR", "LEXICAL"): @@ -859,7 +860,7 @@ def test_limit_results(self): vocab = requests.get(vocab_source).text.splitlines() - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[{"Title": "a " + (" ".join(random.choices(population=vocab, k=25)))} for _ in range(2000)], auto_refresh=False @@ -899,9 +900,10 @@ def test_limit_results_none(self): vocab = requests.get(vocab_source).text.splitlines() tensor_search.add_documents_orchestrator( - config=self.config, index_name=self.index_name_1, - docs=[{"Title": "a " + (" ".join(random.choices(population=vocab, k=25)))} - for _ in range(700)], auto_refresh=False, processes=4, batch_size=50 + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"Title": "a " + (" ".join(random.choices(population=vocab, k=25)))} + for _ in range(700)], 
auto_refresh=False), + processes=4, batch_size=50 ) tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) @@ -926,7 +928,7 @@ def test_pagination_single_field(self): vocab = requests.get(vocab_source).text.splitlines() num_docs = 2000 - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[{"Title": "a " + (" ".join(random.choices(population=vocab, k=25))), "_id": str(i) @@ -1032,7 +1034,7 @@ def test_pagination_multi_field_error(self): } ] - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=False ) @@ -1092,7 +1094,7 @@ def test_image_search_highlights(self): {"_id": "789", "image_field": url_2}, ] - tensor_search.add_documents( + add_docs_caller( config=self.config, auto_refresh=True, index_name=self.index_name_1, docs=docs ) res = tensor_search.search( @@ -1112,7 +1114,7 @@ def test_multi_search(self): {"field_a": "Construction and scaffolding equipment", "_id": 'irrelevant_doc'} ] - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1150,7 +1152,7 @@ def test_multi_search_images(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1199,7 +1201,7 @@ def test_multi_search_check_vector(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1277,7 +1279,7 @@ def test_multi_search_images_edge_cases(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1310,7 +1312,7 @@ def test_multi_search_images_ok_edge_cases(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1330,7 +1332,7 @@ def test_multi_search_images_lexical(self): {"field_a": "Some text about a weird forest", "_id": 'artefact_hippo'} ] - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) @@ -1368,7 +1370,7 @@ def test_image_search(self): } tensor_search.create_vector_index( config=self.config, index_name=self.index_name_1, index_settings=image_index_config) - tensor_search.add_documents( + add_docs_caller( config=self.config, index_name=self.index_name_1, docs=docs, auto_refresh=True ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utils/transition.py b/tests/utils/transition.py new file mode 100644 index 000000000..52372a0ac --- /dev/null +++ b/tests/utils/transition.py @@ -0,0 +1,15 @@ +""" +Our test suite is quite brittle. 
This helps the unit test suite navigate +refactoring transitions in Marqo. +""" +from marqo.tensor_search.tensor_search import add_documents +from marqo.tensor_search.models.add_docs_objects import AddDocsParams +from marqo.config import Config + +def add_docs_caller(config: Config, **kwargs): +    """This represents the call signature of add_documents at commit + https://github.com/marqo-ai/marqo/commit/a884c840020e5f75b85b3d534b235a4a4b8f05b5 + + New tests should NOT use this, and should call add_documents directly + """ + return add_documents(config=config, add_docs_params=AddDocsParams(**kwargs)) \ No newline at end of file
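
As a quick orientation for the hunks above, here is a minimal sketch of the before/after call shapes this migration applies. It is illustrative only and not part of the diff: the `config` and `index_name` values are assumed to come from the usual `MarqoTestCase` fixtures (`self.config`, `self.index_name_1`), and `add_sample_docs` is a hypothetical helper used purely for demonstration.

```python
# Minimal sketch of the add_documents call-signature migration shown in the diff.
# Assumes `config` and `index_name` come from the MarqoTestCase fixtures.
from marqo.tensor_search import tensor_search
from marqo.tensor_search.models.add_docs_objects import AddDocsParams
from tests.utils.transition import add_docs_caller


def add_sample_docs(config, index_name):
    docs = [{"Title": "Extravehicular Mobility Unit (EMU)", "_id": "article_591"}]

    # Old signature (the removed "-" lines): keyword arguments passed straight through.
    # tensor_search.add_documents(config=config, index_name=index_name,
    #                             docs=docs, auto_refresh=True)

    # New signature: the per-request arguments are wrapped in an AddDocsParams object.
    tensor_search.add_documents(
        config=config,
        add_docs_params=AddDocsParams(index_name=index_name, docs=docs, auto_refresh=True),
    )

    # The orchestrator keeps batching/process options at the top level and takes
    # the same AddDocsParams object for the per-request arguments.
    tensor_search.add_documents_orchestrator(
        config=config,
        add_docs_params=AddDocsParams(index_name=index_name, docs=docs, auto_refresh=True),
        batch_size=10,
        processes=1,
    )

    # Transitional shim for tests that have not been migrated yet: it preserves the
    # old keyword-style call and builds AddDocsParams internally.
    add_docs_caller(config=config, index_name=index_name, docs=docs, auto_refresh=True)
```

In the hunks above, tests that only need documents indexed (the score-modifier and search tests) switch to the `add_docs_caller` shim so their call sites stay close to the original, while tests that exercise the new request path (the multimodal-combination and parallel tests) construct `AddDocsParams` directly.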