diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index 63586d7cc..f93fc7be6 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -254,12 +254,12 @@ def add_documents_orchestrator(
         raise errors.InvalidArgError("Batch size can't be less than 1!")
     logger.info(f"batch_size={batch_size} and processes={processes} - batching using a single process")
     return _batch_request(config=config, index_name=index_name, dataset=docs, device=device,
-                          batch_size=batch_size, verbose=False)
+                          batch_size=batch_size, verbose=False, non_tensor_fields=non_tensor_fields)
 
 
 def _batch_request(config: Config, index_name: str, dataset: List[dict],
                    batch_size: int = 100, verbose: bool = True, device=None,
-                   update_mode: str = 'replace') -> List[Dict[str, Any]]:
+                   update_mode: str = 'replace', non_tensor_fields: List[str] = []) -> List[Dict[str, Any]]:
     """Batch by the number of documents"""
     logger.info(f"starting batch ingestion in sizes of {batch_size}")
 
@@ -280,7 +280,7 @@ def verbosely_add_docs(i, docs):
         res = add_documents(
             config=config, index_name=index_name,
             docs=docs, auto_refresh=False, device=device,
-            update_mode=update_mode
+            update_mode=update_mode, non_tensor_fields=non_tensor_fields
         )
         total_batch_time = datetime.datetime.now() - t0
         num_docs = len(docs)
diff --git a/tests/tensor_search/test_add_documents.py b/tests/tensor_search/test_add_documents.py
index 976cbe3c8..230f88248 100644
--- a/tests/tensor_search/test_add_documents.py
+++ b/tests/tensor_search/test_add_documents.py
@@ -11,6 +11,7 @@
 from marqo.errors import IndexNotFoundError, InvalidArgError, BadRequestError
 from marqo.tensor_search import tensor_search, index_meta_cache, backend
 from tests.marqo_test import MarqoTestCase
+import time
 
 
 class TestAddDocuments(MarqoTestCase):
@@ -24,6 +25,14 @@ def setUp(self) -> None:
         except IndexNotFoundError as s:
             pass
 
+    def tearDown(self) -> None:
+        self.index_name_1 = "my-test-index-1"
+        try:
+            tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
+        except IndexNotFoundError as s:
+            pass
+
+
     def _match_all(self, index_name, verbose=True):
         """Helper function"""
         res = requests.get(
@@ -876,7 +885,7 @@ def test_put_documents_orchestrator(self):
                 {"_id": "789", "Temp": 12.5},
             ], auto_refresh=True, update_mode='update', processes=4, batch_size=1)
-
+        time.sleep(3)
         updated_doc = tensor_search.get_document_by_id(
             config=self.config, index_name=self.index_name_1, document_id='789'
         )
@@ -955,3 +964,68 @@ def run():
             assert items[0]['result'] in ['created', 'updated']
             return True
         assert run()
+
+    def test_no_tensor_field_replace(self):
+        # test replace and update workflows
+        tensor_search.add_documents(
+            self.config,
+            docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata2"}],
+            auto_refresh=True, index_name=self.index_name_1
+        )
+        tensor_search.add_documents(
+            self.config,
+            docs=[{"_id": "123", "myfield": "mydata"}],
+            auto_refresh=True, index_name=self.index_name_1,
+            non_tensor_fields=["myfield"]
+        )
+        doc_w_facets = tensor_search.get_document_by_id(
+            self.config, index_name=self.index_name_1, document_id='123', show_vectors=True)
+        assert doc_w_facets[TensorField.tensor_facets] == []
+        assert 'myfield2' not in doc_w_facets
+
+    def test_no_tensor_field_update(self):
+        # test replace and update workflows
+        tensor_search.add_documents(
+            self.config,
+            docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata2"}],
+            auto_refresh=True, index_name=self.index_name_1
+        )
+        tensor_search.add_documents(
+            self.config,
+            docs=[{"_id": "123", "myfield": "mydata"}],
+            auto_refresh=True, index_name=self.index_name_1,
+            non_tensor_fields=["myfield"], update_mode='update'
+        )
+        doc_w_facets = tensor_search.get_document_by_id(
+            self.config, index_name=self.index_name_1, document_id='123', show_vectors=True)
+        assert len(doc_w_facets[TensorField.tensor_facets]) == 1
+        assert 'myfield2' in doc_w_facets[TensorField.tensor_facets][0]
+        assert 'myfield' in doc_w_facets
+        assert 'myfield2' in doc_w_facets
+
+    def test_no_tensor_field_on_empty_ix(self):
+        tensor_search.add_documents(
+            self.config,
+            docs=[{"_id": "123", "myfield": "mydata"}],
+            auto_refresh=True, index_name=self.index_name_1,
+            non_tensor_fields=["myfield"]
+        )
+        doc_w_facets = tensor_search.get_document_by_id(
+            self.config, index_name=self.index_name_1, document_id='123', show_vectors=True)
+        assert doc_w_facets[TensorField.tensor_facets] == []
+        assert 'myfield' in doc_w_facets
+
+    def test_no_tensor_field_on_empty_ix_other_field(self):
+        tensor_search.add_documents(
+            self.config,
+            docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata"}],
+            auto_refresh=True, index_name=self.index_name_1,
+            non_tensor_fields=["myfield"]
+        )
+        doc_w_facets = tensor_search.get_document_by_id(
+            self.config, index_name=self.index_name_1, document_id='123', show_vectors=True)
+        assert len(doc_w_facets[TensorField.tensor_facets]) == 1
+        assert 'myfield2' in doc_w_facets[TensorField.tensor_facets][0]
+        assert 'myfield' not in doc_w_facets[TensorField.tensor_facets][0]
+        assert 'myfield' in doc_w_facets
+        assert 'myfield2' in doc_w_facets
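With this patch, non_tensor_fields is no longer dropped on the single-process batching path:
add_documents_orchestrator hands it to _batch_request, which forwards it to every batched
add_documents call. Below is a minimal sketch of exercising that path, assuming a running
Marqo backend; the Config instance, index name, document contents and field names are
illustrative assumptions, not values taken from this diff.

    from marqo.tensor_search import tensor_search

    config = ...  # assumption: a marqo Config instance, e.g. self.config in MarqoTestCase

    # "description" should now be skipped during vectorisation for every batch,
    # because _batch_request forwards non_tensor_fields to each add_documents call.
    tensor_search.add_documents_orchestrator(
        config=config, index_name="my-test-index-1",
        docs=[
            {"_id": "1", "title": "hello", "description": "plain metadata"},
            {"_id": "2", "title": "world", "description": "more metadata"},
        ],
        auto_refresh=True, batch_size=1, processes=1,
        non_tensor_fields=["description"],
    )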