Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ def __init__(
self.tokenizer = tokenizer or get_tokenizer()
self.embedding_vectorstore_key = embedding_vectorstore_key

def filter_by_entity_keys(self, entity_keys: list[int] | list[str]):
"""Filter entity text embeddings by entity keys."""
self.entity_text_embeddings.filter_by_id(entity_keys)

def build_context(
self,
query: str,
Expand Down
20 changes: 2 additions & 18 deletions graphrag/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,8 @@ def load_documents(
if len(batch) > 0:
self.db_connection.upload_documents(batch)

def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by a list of ids."""
if include_ids is None or len(include_ids) == 0:
self.query_filter = None
# Returning to keep consistency with other methods, but not needed
return self.query_filter

# More info about odata filtering here: https://learn.microsoft.com/en-us/azure/search/search-query-odata-search-in-function
# search.in is faster that joined and/or conditions
id_filter = ",".join([f"{id!s}" for id in include_ids])
self.query_filter = f"search.in({self.id_field}, '{id_filter}', ',')"

# Returning to keep consistency with other methods, but not needed
# TODO: Refactor on a future PR
return self.query_filter

def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
self, query_embedding: list[float], k: int = 10
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
vectorized_query = VectorizedQuery(
Expand All @@ -193,7 +177,7 @@ def similarity_search_by_vector(
]

def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
self, text: str, text_embedder: TextEmbedder, k: int = 10
) -> list[VectorStoreSearchResult]:
"""Perform a text-based similarity search."""
query_embedding = text_embedder(text)
Expand Down
8 changes: 2 additions & 6 deletions graphrag/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,16 @@ def load_documents(

@abstractmethod
def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
self, query_embedding: list[float], k: int = 10
) -> list[VectorStoreSearchResult]:
"""Perform ANN search by vector."""

@abstractmethod
def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
self, text: str, text_embedder: TextEmbedder, k: int = 10
) -> list[VectorStoreSearchResult]:
"""Perform ANN search by text."""

@abstractmethod
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by id."""

@abstractmethod
def search_by_id(self, id: str) -> VectorStoreDocument:
"""Search for a document by id."""
18 changes: 2 additions & 16 deletions graphrag/vector_stores/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def load_documents(
self._container_client.upsert_item(doc_json)

def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
self, query_embedding: list[float], k: int = 10
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
if self._container_client is None:
Expand Down Expand Up @@ -241,7 +241,7 @@ def cosine_similarity(a, b):
]

def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
self, text: str, text_embedder: TextEmbedder, k: int = 10
) -> list[VectorStoreSearchResult]:
"""Perform a text-based similarity search."""
query_embedding = text_embedder(text)
Expand All @@ -251,20 +251,6 @@ def similarity_search_by_text(
)
return []

def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by a list of ids."""
if include_ids is None or len(include_ids) == 0:
self.query_filter = None
else:
if isinstance(include_ids[0], str):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
else:
id_filter = ", ".join([str(id) for id in include_ids])
self.query_filter = (
f"SELECT * FROM c WHERE c.{self.id_field} IN ({id_filter})" # noqa: S608
)
return self.query_filter

def search_by_id(self, id: str) -> VectorStoreDocument:
"""Search for a document by id."""
if self._container_client is None:
Expand Down
18 changes: 2 additions & 16 deletions graphrag/vector_stores/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,8 @@ def load_documents(
if data:
self.document_collection.add(data)

def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
"""Build a query filter to filter documents by id."""
if len(include_ids) == 0:
self.query_filter = None
else:
if isinstance(include_ids[0], str):
id_filter = ", ".join([f"'{id}'" for id in include_ids])
self.query_filter = f"{self.id_field} in ({id_filter})"
else:
self.query_filter = (
f"{self.id_field} in ({', '.join([str(id) for id in include_ids])})"
)
return self.query_filter

def similarity_search_by_vector(
self, query_embedding: list[float] | np.ndarray, k: int = 10, **kwargs: Any
self, query_embedding: list[float] | np.ndarray, k: int = 10
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
if self.query_filter:
Expand Down Expand Up @@ -151,7 +137,7 @@ def similarity_search_by_vector(
]

def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
self, text: str, text_embedder: TextEmbedder, k: int = 10
) -> list[VectorStoreSearchResult]:
"""Perform a similarity search using a given input text."""
query_embedding = text_embedder(text)
Expand Down
9 changes: 0 additions & 9 deletions tests/integration/vector_stores/test_azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ async def test_vector_store_operations(
assert mock_index_client.create_or_update_index.called
assert mock_search_client.upload_documents.called

filter_query = vector_store.filter_by_id(["doc1", "doc2"])
assert filter_query == "search.in(id, 'doc1,doc2', ',')"

vector_results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
Expand Down Expand Up @@ -215,12 +212,6 @@ async def test_vector_store_customization(
assert mock_index_client.create_or_update_index.called
assert mock_search_client.upload_documents.called

filter_query = vector_store_custom.filter_by_id(["doc1", "doc2"])
assert (
filter_query
== f"search.in({vector_store_custom.id_field}, 'doc1,doc2', ',')"
)

vector_results = vector_store_custom.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
Expand Down
4 changes: 0 additions & 4 deletions tests/integration/vector_stores/test_cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def test_vector_store_operations():
]
vector_store.load_documents(docs)

vector_store.filter_by_id(["doc1"])

doc = vector_store.search_by_id("doc1")
assert doc.id == "doc1"
assert doc.text == "This is document 1"
Expand Down Expand Up @@ -140,8 +138,6 @@ def test_vector_store_customization():
]
vector_store.load_documents(docs)

vector_store.filter_by_id(["doc1"])

doc = vector_store.search_by_id("doc1")
assert doc.id == "doc1"
assert doc.text == "This is document 1"
Expand Down
14 changes: 3 additions & 11 deletions tests/integration/vector_stores/test_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ def test_vector_store_operations(self, sample_documents):
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"

filter_query = vector_store.filter_by_id(["1"])
assert filter_query == "id in ('1')"

results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
Expand Down Expand Up @@ -186,16 +183,14 @@ def test_filter_search(self, sample_documents_categories):
vector_store.load_documents(sample_documents_categories)

# Filter to include only documents about animals
vector_store.filter_by_id(["1", "2"])
results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=3
)

# Should return at most 2 documents (the filtered ones)
assert len(results) <= 2
# Should return at most 3 documents (the filtered ones)
assert len(results) <= 3
ids = [result.document.id for result in results]
assert "3" not in ids
assert set(ids).issubset({"1", "2"})
assert set(ids).issubset({"1", "2", "3"})
finally:
shutil.rmtree(temp_dir)

Expand Down Expand Up @@ -230,9 +225,6 @@ def test_vector_store_customization(self, sample_documents):
assert np.allclose(doc.vector, [0.1, 0.2, 0.3, 0.4, 0.5])
assert doc.attributes["title"] == "Doc 1"

filter_query = vector_store.filter_by_id(["1"])
assert filter_query == f"{vector_store.id_field} in ('1')"

results = vector_store.similarity_search_by_vector(
[0.1, 0.2, 0.3, 0.4, 0.5], k=2
)
Expand Down
Loading