Dupe IDs are handled when use_existing_tensors=True #390

Merged: 8 commits, Mar 16, 2023. Changes from all commits.
2 changes: 1 addition & 1 deletion src/marqo/errors.py
@@ -59,7 +59,7 @@ def __str__(self) -> str:

class MarqoWebError(Exception):

- status_code: int = None
+ status_code: int = 500
error_type: str = None
message: str = None
code: str = None
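Side note: with this default, any MarqoWebError subclass that never assigns status_code now surfaces as a generic HTTP 500 rather than None. A minimal sketch of the effect (the subclass below is hypothetical, not part of this PR):

class MarqoWebError(Exception):
    status_code: int = 500  # fallback for subclasses that don't set their own
    error_type: str = None
    message: str = None
    code: str = None

class HypotheticalInternalFailure(MarqoWebError):
    pass  # sets no status_code, so handlers see the inherited 500

assert HypotheticalInternalFailure.status_code == 500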
37 changes: 24 additions & 13 deletions src/marqo/tensor_search/tensor_search.py
@@ -373,6 +373,7 @@ def _infer_opensearch_data_type(

def _get_chunks_for_field(field_name: str, doc_id: str, doc):
# Find the chunks with a specific __field_name in a doc
# Note: for a chunkless doc (nothing was tensorised) --> doc["_source"]["__chunks"] == []
return [chunk for chunk in doc["_source"]["__chunks"] if chunk["__field_name"] == field_name]
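To illustrate the shape this helper assumes (field names below are illustrative only), each entry in doc["_source"]["__chunks"] carries a __field_name key, and a chunkless doc simply yields an empty list:

doc = {"_source": {"__chunks": [
    {"__field_name": "title", "__field_content": "doc one"},
    {"__field_name": "desc", "__field_content": "blah blah"},
]}}
# Keeps only the "title" chunk:
assert len([c for c in doc["_source"]["__chunks"] if c["__field_name"] == "title"]) == 1

chunkless_doc = {"_source": {"__chunks": []}}
# Nothing was tensorised, so the filter returns []:
assert [c for c in chunkless_doc["_source"]["__chunks"] if c["__field_name"] == "title"] == []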


@@ -456,8 +457,12 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres
f"images for {batch_size} docs using {image_download_thread_count} threads ")

if update_mode == 'replace' and use_existing_tensors:
- # Get existing documents
- doc_ids = [doc["_id"] for doc in docs if "_id" in doc]
+ doc_ids = []
+
+ # Iterate through the list in reverse so that only the latest doc with a duplicated ID is added.
+ for i in range(len(docs) - 1, -1, -1):
+     if ("_id" in docs[i]) and (docs[i]["_id"] not in doc_ids):
+         doc_ids.append(docs[i]["_id"])
existing_docs = _get_documents_for_upsert(config=config, index_name=index_name, document_ids=doc_ids)

for i, doc in enumerate(docs):
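In isolation, the reverse loop behaves like this (note that doc_ids comes out in reverse insertion order, and only the last occurrence of a duplicated ID survives):

docs = [{"_id": "1"}, {"_id": "3", "title": "doc 3a"}, {"no_id_here": 1}, {"_id": "3", "title": "doc 3b"}]
doc_ids = []
for i in range(len(docs) - 1, -1, -1):
    if ("_id" in docs[i]) and (docs[i]["_id"] not in doc_ids):
        doc_ids.append(docs[i]["_id"])
assert doc_ids == ["3", "1"]  # "3" appears once, taken from the last ("doc 3b") entry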
@@ -497,6 +502,8 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres
# have IDs:
elif len(matching_doc) == 0:
existing_doc = {"found": False}
else:
raise errors.InternalError(message=f"Upsert: found {len(matching_doc)} matching docs for {doc_id} when only 1 or 0 should have been found.")
else:
indexing_instructions["update"]["_id"] = doc_id
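For orientation, the replace-mode branch above enforces that an ID matches at most one fetched doc. A self-contained sketch under stated assumptions: the matching_doc lookup is outside this hunk and is assumed, and RuntimeError stands in for errors.InternalError here:

def resolve_existing_doc(existing_docs: dict, doc_id: str) -> dict:
    # Assumed lookup shape: existing_docs["docs"] is a list of per-ID results
    matching_doc = [d for d in existing_docs["docs"] if d["_id"] == doc_id]
    if len(matching_doc) == 1:
        return matching_doc[0]
    if len(matching_doc) == 0:
        return {"found": False}
    # More than one hit means the upsert fetch itself misbehaved: an internal bug, not bad user input
    raise RuntimeError(f"Upsert: found {len(matching_doc)} matching docs for {doc_id} "
                       f"when only 1 or 0 should have been found.")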

@@ -541,7 +548,6 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres
# Check if content of this field changed. If no, skip all chunking and vectorisation
if ((update_mode == 'replace') and use_existing_tensors and existing_doc["found"]
and (field in existing_doc["_source"]) and (existing_doc["_source"][field] == field_content)):
- # logger.info(f"Using existing vectors for doc {doc_id}, field {field}. Content remains unchanged.")
chunks_to_append = _get_chunks_for_field(field_name=field, doc_id=doc_id, doc=existing_doc)

# Chunk and vectorise, since content changed.
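Read as a predicate, the reuse check above only short-circuits chunking and vectorisation when every condition holds. A hedged standalone sketch (the function name is mine, not the codebase's):

def can_reuse_vectors(update_mode, use_existing_tensors, existing_doc, field, field_content):
    # Any content difference, missing field, or missing old doc forces re-chunking and re-vectorising
    return (update_mode == 'replace' and use_existing_tensors
            and existing_doc.get("found", False)
            and field in existing_doc.get("_source", {})
            and existing_doc["_source"][field] == field_content)

assert can_reuse_vectors('replace', True, {"found": True, "_source": {"t": "x"}}, "t", "x")
assert not can_reuse_vectors('replace', True, {"found": True, "_source": {"t": "x"}}, "t", "y")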
@@ -651,9 +657,7 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres
TensorField.field_content: text_chunk,
TensorField.field_name: field
})

- # Add chunks_to_append along with doc metadata to total chunks
elif isinstance(field_content, dict):
if mappings[field]["type"]=="multimodal_combination":
combo_chunk, combo_document_is_valid, unsuccessful_doc_to_append, combo_vectorise_time_to_add,\
@@ -669,9 +673,11 @@ def add_documents(config: Config, index_name: str, docs: List[dict], auto_refres
if field not in new_obj_fields:
new_obj_fields[field] = set()
new_obj_fields[field] = new_obj_fields[field].union(new_fields_from_multimodal_combination)
# TODO: we may want to use chunks_to_append here to make it uniform with use_existing_tensors and normal vectorisation
chunks.append({**combo_chunk, **chunk_values_for_filtering})
continue


# Add chunks_to_append along with doc metadata to total chunks
for chunk in chunks_to_append:
chunks.append({**chunk, **chunk_values_for_filtering})
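Each reused chunk is spread together with the doc-level filterable values, so reused vectors stay searchable with fresh metadata. A tiny illustration (both dicts are invented for the example):

chunks_to_append = [{"__field_name": "title", "__field_content": "doc 3b"}]
chunk_values_for_filtering = {"genre": "sci-fi"}
chunks = [{**chunk, **chunk_values_for_filtering} for chunk in chunks_to_append]
assert chunks == [{"__field_name": "title", "__field_content": "doc 3b", "genre": "sci-fi"}]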

@@ -893,6 +899,7 @@ def _get_documents_for_upsert(
f"set by the environment variable `{EnvVars.MARQO_MAX_RETRIEVABLE_DOCS}`")

# Chunk Docs (get field name, field content, vectors)

chunk_docs = [
{"_index": index_name, "_id": doc_id,
"_source": {"include": [f"__chunks.__field_content", f"__chunks.__field_name", f"__chunks.__vector_*"]}}
@@ -913,19 +920,21 @@

# Combine the 2 query results (loop through each doc id)
combined_result = []
- for doc_id in document_ids:
+ for doc_id in valid_doc_ids:
# There should always be 2 results per doc.
result_list = [doc for doc in res["docs"] if doc["_id"] == doc_id]

+ if len(result_list) == 0:
+     continue
if len(result_list) not in (2, 0):
-     raise errors.MarqoWebError(f"Bad request for existing documents. "
+     raise errors.InternalError(f"Internal error fetching old documents. "
f"There are {len(result_list)} results for doc id {doc_id}.")

for result in result_list:
if result["found"]:
doc_in_results = True
if result["_source"]["__chunks"] == []:
if ("__chunks" in result["_source"]) and (result["_source"]["__chunks"] == []):
res_data = result
else:
res_chunks = result
@@ -934,12 +943,14 @@
dummy_res = result
[Collaborator review comment] nit: may be more appropriate to call this something like not_found_res
break

- # Put the chunks list in res_data, so it's complete
+ # Put the chunks list in res_data, so it contains all doc data
if doc_in_results:
- res_data["_source"]["__chunks"] = res_chunks["_source"]["__chunks"]
+ # Only add chunks if not a chunkless doc
+ if res_chunks["_source"]:
+     res_data["_source"]["__chunks"] = res_chunks["_source"]["__chunks"]
combined_result.append(res_data)
else:
- # This result just says that the doc was not found
+ # This result just says that the doc was not found ("found": False)
combined_result.append(dummy_res)

res["docs"] = combined_result
107 changes: 105 additions & 2 deletions tests/tensor_search/test_add_documents.py
@@ -66,6 +66,51 @@ def test_add_plain_id_field(self):
"title 1": "content 1",
"desc 2": "content 2. blah blah blah"
}

def test_add_documents_dupe_ids(self):
    """
    Should only use the latest inserted ID. Make sure it doesn't get the first/middle one
    """
    tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[
        {
            "_id": "3",
            "title": "doc 3b"
        },
    ], auto_refresh=True)

    doc_3_solo = tensor_search.get_document_by_id(
        config=self.config, index_name=self.index_name_1,
        document_id="3", show_vectors=True)

    tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
    tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[
        {
            "_id": "1",
            "title": "doc 1"
        },
        {
            "_id": "2",
            "title": "doc 2",
        },
        {
            "_id": "3",
            "title": "doc 3a",
        },
        {
            "_id": "3",
            "title": "doc 3b"
        },
    ], auto_refresh=True)

    doc_3_duped = tensor_search.get_document_by_id(
        config=self.config, index_name=self.index_name_1,
        document_id="3", show_vectors=True)

    self.assertEqual(doc_3_solo, doc_3_duped)


def test_update_docs_update_chunks(self):
"""Updating a doc needs to update the corresponding chunks"
@@ -259,16 +304,74 @@ def test_add_documents_validation(self):
{"_id": "to_fail_567", "some other obj": AssertionError}],
[{"_id": "to_fail_567", "blahblah": max}]
]
- for update_mode in ('replace', 'update'):
+ # For update
+ for bad_doc_arg in bad_doc_args:
+     add_res = tensor_search.add_documents(
+         config=self.config, index_name=self.index_name_1,
+         docs=bad_doc_arg, auto_refresh=True, update_mode='update')
+     assert add_res['errors'] is True
+     assert all(['error' in item for item in add_res['items'] if item['_id'].startswith('to_fail')])
+     assert all(['result' in item
+                 for item in add_res['items'] if item['_id'].startswith('to_pass')])
+
+ # For replace, check with use_existing_tensors True and False
+ for use_existing_tensors_flag in (True, False):
      for bad_doc_arg in bad_doc_args:
          add_res = tensor_search.add_documents(
              config=self.config, index_name=self.index_name_1,
-             docs=bad_doc_arg, auto_refresh=True, update_mode=update_mode)
+             docs=bad_doc_arg, auto_refresh=True, update_mode='replace', use_existing_tensors=use_existing_tensors_flag)
          assert add_res['errors'] is True
          assert all(['error' in item for item in add_res['items'] if item['_id'].startswith('to_fail')])
          assert all(['result' in item
                      for item in add_res['items'] if item['_id'].startswith('to_pass')])


def test_add_documents_id_validation(self):
    """These bad docs should return errors"""
    bad_doc_args = [
        # Wrong data types for ID
        # Tuple: (doc_list, number of docs that should succeed)
        ([{"_id": {}, "field_1": 1234}], 0),
        ([{"_id": dict(), "field_1": 1234}], 0),
        ([{"_id": [1, 2, 3], "field_1": 1234}], 0),
        ([{"_id": 4, "field_1": 1234}], 0),
        ([{"_id": None, "field_1": 1234}], 0),

        ([{"_id": "proper id", "field_1": 5678},
          {"_id": ["bad", "id"], "field_1": "zzz"},
          {"_id": "proper id 2", "field_1": 90}], 2)
    ]

    # For update
    for bad_doc_arg in bad_doc_args:
        add_res = tensor_search.add_documents(
            config=self.config, index_name=self.index_name_1,
            docs=bad_doc_arg[0], auto_refresh=True, update_mode='update')

        assert add_res['errors'] is True

        succeeded_count = 0
        for item in add_res['items']:
            if 'result' in item:
                succeeded_count += 1

        assert succeeded_count == bad_doc_arg[1]

    # For replace, check with use_existing_tensors True and False
    for use_existing_tensors_flag in (True, False):
        for bad_doc_arg in bad_doc_args:
            add_res = tensor_search.add_documents(
                config=self.config, index_name=self.index_name_1,
                docs=bad_doc_arg[0], auto_refresh=True, update_mode='replace', use_existing_tensors=use_existing_tensors_flag)
            assert add_res['errors'] is True
            succeeded_count = 0
            for item in add_res['items']:
                if 'result' in item:
                    succeeded_count += 1

            assert succeeded_count == bad_doc_arg[1]

def test_add_documents_list_non_tensor_validation(self):
"""This doc is valid but should return error because my_field is not marked non-tensor"""
bad_doc_args = [