Skip to content

Commit

Permalink
Merge pull request #167 from marqo-ai/patch_non_tensor_field
Browse files: browse the repository at this point in the history
passing non_tensor_fields to _batch
  • Loading branch information
pandu-k authored Nov 9, 2022
2 parents 8d76f0b + f97a6ee commit 76adc10
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ def add_documents_orchestrator(
raise errors.InvalidArgError("Batch size can't be less than 1!")
logger.info(f"batch_size={batch_size} and processes={processes} - batching using a single process")
return _batch_request(config=config, index_name=index_name, dataset=docs, device=device,
batch_size=batch_size, verbose=False)
batch_size=batch_size, verbose=False, non_tensor_fields=non_tensor_fields)


def _batch_request(config: Config, index_name: str, dataset: List[dict],
batch_size: int = 100, verbose: bool = True, device=None,
update_mode: str = 'replace') -> List[Dict[str, Any]]:
update_mode: str = 'replace', non_tensor_fields: List[str] = []) -> List[Dict[str, Any]]:
"""Batch by the number of documents"""
logger.info(f"starting batch ingestion in sizes of {batch_size}")

Expand All @@ -280,7 +280,7 @@ def verbosely_add_docs(i, docs):
res = add_documents(
config=config, index_name=index_name,
docs=docs, auto_refresh=False, device=device,
update_mode=update_mode
update_mode=update_mode, non_tensor_fields=non_tensor_fields
)
total_batch_time = datetime.datetime.now() - t0
num_docs = len(docs)
Expand Down
76 changes: 75 additions & 1 deletion tests/tensor_search/test_add_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from marqo.errors import IndexNotFoundError, InvalidArgError, BadRequestError
from marqo.tensor_search import tensor_search, index_meta_cache, backend
from tests.marqo_test import MarqoTestCase
import time


class TestAddDocuments(MarqoTestCase):
Expand All @@ -24,6 +25,14 @@ def setUp(self) -> None:
except IndexNotFoundError as s:
pass

def tearDown(self) -> None:
    """Delete the test index after each test.

    A missing index is not an error here: a test may never have created it,
    or may already have removed it.
    """
    self.index_name_1 = "my-test-index-1"
    try:
        tensor_search.delete_index(index_name=self.index_name_1, config=self.config)
    except IndexNotFoundError:
        # Nothing to clean up.
        pass


def _match_all(self, index_name, verbose=True):
"""Helper function"""
res = requests.get(
Expand Down Expand Up @@ -876,7 +885,7 @@ def test_put_documents_orchestrator(self):
{"_id": "789", "Temp": 12.5},
],
auto_refresh=True, update_mode='update', processes=4, batch_size=1)

time.sleep(3)
updated_doc = tensor_search.get_document_by_id(
config=self.config, index_name=self.index_name_1, document_id='789'
)
Expand Down Expand Up @@ -955,3 +964,68 @@ def run():
assert items[0]['result'] in ['created', 'updated']
return True
assert run()

def test_no_tensor_field_replace(self):
    """Re-adding a doc whose only field is non-tensor leaves no facets behind.

    The second add_documents call passes no update_mode, so it exercises
    whatever add_documents does by default (the test name suggests
    'replace' -- confirm against add_documents).
    """
    doc_id = "123"

    # First write: two fields, no non_tensor_fields given.
    initial_docs = [{"_id": doc_id, "myfield": "mydata", "myfield2": "mydata2"}]
    tensor_search.add_documents(
        self.config,
        index_name=self.index_name_1,
        docs=initial_docs,
        auto_refresh=True,
    )

    # Second write: same _id, only 'myfield', which is marked non-tensor.
    replacement_docs = [{"_id": doc_id, "myfield": "mydata"}]
    tensor_search.add_documents(
        self.config,
        index_name=self.index_name_1,
        docs=replacement_docs,
        auto_refresh=True,
        non_tensor_fields=["myfield"],
    )

    stored = tensor_search.get_document_by_id(
        self.config,
        index_name=self.index_name_1,
        document_id=doc_id,
        show_vectors=True,
    )
    # No facets may remain, and the field that was dropped must be gone.
    assert stored[TensorField.tensor_facets] == []
    assert 'myfield2' not in stored

def test_no_tensor_field_update(self):
    """Updating a doc with a non-tensor field keeps the other field's facet.

    With update_mode='update', the earlier 'myfield2' field and its facet are
    expected to survive, while 'myfield' is stored without a facet.
    """
    doc_id = "123"

    # First write: two fields, no non_tensor_fields given.
    initial_docs = [{"_id": doc_id, "myfield": "mydata", "myfield2": "mydata2"}]
    tensor_search.add_documents(
        self.config,
        index_name=self.index_name_1,
        docs=initial_docs,
        auto_refresh=True,
    )

    # Second write: update the same _id, marking 'myfield' as non-tensor.
    patch_docs = [{"_id": doc_id, "myfield": "mydata"}]
    tensor_search.add_documents(
        self.config,
        index_name=self.index_name_1,
        docs=patch_docs,
        auto_refresh=True,
        update_mode='update',
        non_tensor_fields=["myfield"],
    )

    stored = tensor_search.get_document_by_id(
        self.config,
        index_name=self.index_name_1,
        document_id=doc_id,
        show_vectors=True,
    )
    facets = stored[TensorField.tensor_facets]
    # Exactly one facet remains and it belongs to 'myfield2'.
    assert len(facets) == 1
    assert 'myfield2' in facets[0]
    # Both fields are still present on the document itself.
    assert 'myfield' in stored
    assert 'myfield2' in stored

def test_no_tensor_field_on_empty_ix(self):
    """A doc whose only field is non-tensor is stored with no facets.

    This is the first write the test makes to the index (presumably the
    index does not exist beforehand -- confirm against setUp).
    """
    doc_id = "123"
    docs_to_add = [{"_id": doc_id, "myfield": "mydata"}]
    tensor_search.add_documents(
        self.config,
        index_name=self.index_name_1,
        docs=docs_to_add,
        auto_refresh=True,
        non_tensor_fields=["myfield"],
    )

    stored = tensor_search.get_document_by_id(
        self.config,
        index_name=self.index_name_1,
        document_id=doc_id,
        show_vectors=True,
    )
    # The field is kept on the document but has no facet.
    assert stored[TensorField.tensor_facets] == []
    assert 'myfield' in stored

def test_no_tensor_field_on_empty_ix_other_field(self):
    """Only the field not listed in non_tensor_fields gets a facet.

    Adds one document with two fields, marking only 'myfield' as non-tensor,
    then checks that the single facet belongs to 'myfield2'.
    """
    doc_id = "123"
    docs_to_add = [{"_id": doc_id, "myfield": "mydata", "myfield2": "mydata"}]
    tensor_search.add_documents(
        self.config,
        index_name=self.index_name_1,
        docs=docs_to_add,
        auto_refresh=True,
        non_tensor_fields=["myfield"],
    )

    stored = tensor_search.get_document_by_id(
        self.config,
        index_name=self.index_name_1,
        document_id=doc_id,
        show_vectors=True,
    )
    facets = stored[TensorField.tensor_facets]
    # Exactly one facet, for 'myfield2' and not for 'myfield'.
    assert len(facets) == 1
    only_facet = facets[0]
    assert 'myfield2' in only_facet
    assert 'myfield' not in only_facet
    # Both fields are still present on the document itself.
    assert 'myfield' in stored
    assert 'myfield2' in stored

0 comments on commit 76adc10

Please sign in to comment.