Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions graphrag/config/models/vector_store_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@ def _validate_url(self) -> None:
default=vector_store_defaults.database_name,
)

overwrite: bool = Field(
description="Overwrite the existing data.",
default=vector_store_defaults.overwrite,
)

embeddings_schema: dict[str, VectorStoreSchemaConfig] = {}

def _validate_embeddings_schema(self) -> None:
Expand Down
21 changes: 7 additions & 14 deletions graphrag/index/operations/embed_text/embed_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,27 +44,21 @@ async def embed_text(
msg = f"Column {id_column} not found in input dataframe with columns {input.columns}"
raise ValueError(msg)

total_rows = 0
for row in input[embed_column]:
if isinstance(row, list):
total_rows += len(row)
else:
total_rows += 1
vector_store.create_index()

i = 0
starting_index = 0
index = 0

all_results = []

num_total_batches = (input.shape[0] + batch_size - 1) // batch_size
while batch_size * i < input.shape[0]:
while batch_size * index < input.shape[0]:
logger.info(
"uploading text embeddings batch %d/%d of size %d to vector store",
i + 1,
index + 1,
num_total_batches,
batch_size,
)
batch = input.iloc[batch_size * i : batch_size * (i + 1)]
batch = input.iloc[batch_size * index : batch_size * (index + 1)]
texts: list[str] = batch[embed_column].tolist()
ids: list[str] = batch[id_column].tolist()
result = await run_embed_text(
Expand Down Expand Up @@ -93,8 +87,7 @@ async def embed_text(
)
documents.append(document)

vector_store.load_documents(documents, True)
starting_index += len(documents)
i += 1
vector_store.load_documents(documents)
index += 1

return all_results
97 changes: 48 additions & 49 deletions graphrag/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,58 +74,57 @@ def connect(self, **kwargs: Any) -> Any:
not_supported_error = "Azure AI Search expects `url`."
raise ValueError(not_supported_error)

def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
def create_index(self) -> None:
"""Load documents into an Azure AI Search index."""
if overwrite:
if (
self.index_name is not None
and self.index_name in self.index_client.list_index_names()
):
self.index_client.delete_index(self.index_name)

# Configure vector search profile
vector_search = VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(
name="HnswAlg",
parameters=HnswParameters(
metric=VectorSearchAlgorithmMetric.COSINE
),
)
],
profiles=[
VectorSearchProfile(
name=self.vector_search_profile_name,
algorithm_configuration_name="HnswAlg",
)
],
)
# Configure the index
index = SearchIndex(
name=self.index_name if self.index_name else "",
fields=[
SimpleField(
name=self.id_field,
type=SearchFieldDataType.String,
key=True,
),
SearchField(
name=self.vector_field,
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
hidden=False, # DRIFT needs to return the vector for client-side similarity
vector_search_dimensions=self.vector_size,
vector_search_profile_name=self.vector_search_profile_name,
if (
self.index_name is not None
and self.index_name in self.index_client.list_index_names()
):
self.index_client.delete_index(self.index_name)

# Configure vector search profile
vector_search = VectorSearch(
algorithms=[
HnswAlgorithmConfiguration(
name="HnswAlg",
parameters=HnswParameters(
metric=VectorSearchAlgorithmMetric.COSINE
),
],
vector_search=vector_search,
)
self.index_client.create_or_update_index(
index,
)
)
],
profiles=[
VectorSearchProfile(
name=self.vector_search_profile_name,
algorithm_configuration_name="HnswAlg",
)
],
)
# Configure the index
index = SearchIndex(
name=self.index_name if self.index_name else "",
fields=[
SimpleField(
name=self.id_field,
type=SearchFieldDataType.String,
key=True,
),
SearchField(
name=self.vector_field,
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
hidden=False, # DRIFT needs to return the vector for client-side similarity
vector_search_dimensions=self.vector_size,
vector_search_profile_name=self.vector_search_profile_name,
),
],
vector_search=vector_search,
)
self.index_client.create_or_update_index(
index,
)

def load_documents(self, documents: list[VectorStoreDocument]) -> None:
"""Load documents into an Azure AI Search index."""
batch = [
{
self.id_field: doc.id,
Expand Down
8 changes: 5 additions & 3 deletions graphrag/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ def connect(self, **kwargs: Any) -> None:
"""Connect to vector storage."""

@abstractmethod
def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
def create_index(self) -> None:
"""Create index."""

@abstractmethod
def load_documents(self, documents: list[VectorStoreDocument]) -> None:
"""Load documents into the vector-store."""

@abstractmethod
Expand Down
11 changes: 5 additions & 6 deletions graphrag/vector_stores/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,19 +149,18 @@ def _container_exists(self) -> bool:
]
return self._container_name in existing_container_names

def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
def create_index(self) -> None:
"""Load documents into CosmosDB."""
# Create a CosmosDB container on overwrite
if overwrite:
self._delete_container()
self._create_container()
self._delete_container()
self._create_container()

if self._container_client is None:
msg = "Container client is not initialized."
raise ValueError(msg)

def load_documents(self, documents: list[VectorStoreDocument]) -> None:
"""Load documents into CosmosDB."""
# Upload documents to CosmosDB
for doc in documents:
if doc.vector is not None:
Expand Down
47 changes: 26 additions & 21 deletions graphrag/vector_stores/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,33 @@ def connect(self, **kwargs: Any) -> Any:
if self.index_name and self.index_name in self.db_connection.table_names():
self.document_collection = self.db_connection.open_table(self.index_name)

def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
def create_index(self) -> None:
"""Create index."""
dummy_vector = np.zeros(self.vector_size, dtype=np.float32)
flat_array = pa.array(dummy_vector, type=pa.float32())
vector_column = pa.FixedSizeListArray.from_arrays(flat_array, self.vector_size)

data = pa.table({
self.id_field: pa.array(["__DUMMY__"], type=pa.string()),
self.vector_field: vector_column,
})

self.document_collection = self.db_connection.create_table(
self.index_name if self.index_name else "",
data=data,
mode="overwrite",
schema=data.schema,
)

# Create the vector index now that the table schema exists
self.document_collection.create_index(
vector_column_name=self.vector_field, index_type="IVF_FLAT"
)

def load_documents(self, documents: list[VectorStoreDocument]) -> None:
"""Load documents into vector storage."""
self.document_collection.delete(f"{self.id_field} = '__DUMMY__'")

# Step 1: Prepare data columns manually
ids = []
vectors = []
Expand Down Expand Up @@ -68,31 +91,13 @@ def load_documents(
self.vector_field: vector_column,
})

# NOTE: If modifying the next section of code, ensure that the schema remains the same.
# The pyarrow format of the 'vector' field may change if the order of operations is changed
# and will break vector search.
if overwrite:
if data:
self.document_collection = self.db_connection.create_table(
self.index_name if self.index_name else "",
data=data,
mode="overwrite",
schema=data.schema,
)
else:
self.document_collection = self.db_connection.create_table(
self.index_name if self.index_name else "", mode="overwrite"
)
self.document_collection.create_index(
vector_column_name=self.vector_field, index_type="IVF_FLAT"
)
else:
# add data to existing table
self.document_collection = self.db_connection.open_table(
self.index_name if self.index_name else ""
)
if data:
self.document_collection.add(data)

def similarity_search_by_vector(
self, query_embedding: list[float] | np.ndarray, k: int = 10
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/vector_stores/test_azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ async def test_vector_store_operations(
"vector": [0.1, 0.2, 0.3, 0.4, 0.5],
}

vector_store.create_index()
vector_store.load_documents(sample_documents)
assert mock_index_client.create_or_update_index.called
assert mock_search_client.upload_documents.called
Expand Down Expand Up @@ -188,6 +189,7 @@ async def test_vector_store_customization(
vector_store_custom.vector_field: [0.1, 0.2, 0.3, 0.4, 0.5],
}

vector_store_custom.create_index()
vector_store_custom.load_documents(sample_documents)
assert mock_index_client.create_or_update_index.called
assert mock_search_client.upload_documents.called
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/vector_stores/test_cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def test_vector_store_operations():
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
),
]

vector_store.create_index()
vector_store.load_documents(docs)

doc = vector_store.search_by_id("doc1")
Expand Down Expand Up @@ -84,6 +86,7 @@ def test_clear():
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
)

vector_store.create_index()
vector_store.load_documents([doc])
result = vector_store.search_by_id("test")
assert result.id == "test"
Expand Down Expand Up @@ -122,6 +125,8 @@ def test_vector_store_customization():
vector=[0.2, 0.3, 0.4, 0.5, 0.6],
),
]

vector_store.create_index()
vector_store.load_documents(docs)

doc = vector_store.search_by_id("doc1")
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/vector_stores/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ def __init__(self, **kwargs):
def connect(self, **kwargs):
pass

def load_documents(self, documents, overwrite=True):
def create_index(self, **kwargs):
pass

def load_documents(self, documents):
pass

def similarity_search_by_vector(self, query_embedding, k=10, **kwargs):
Expand Down
14 changes: 10 additions & 4 deletions tests/integration/vector_stores/test_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def test_vector_store_operations(self, sample_documents):
)
)
vector_store.connect(db_uri=temp_dir)
vector_store.create_index()
vector_store.load_documents(sample_documents[:2])

if vector_store.index_name:
Expand All @@ -83,7 +84,8 @@ def test_vector_store_operations(self, sample_documents):
assert isinstance(results[0].score, float)

# Test append mode
vector_store.load_documents([sample_documents[2]], overwrite=False)
vector_store.create_index()
vector_store.load_documents([sample_documents[2]])
result = vector_store.search_by_id("3")
assert result.id == "3"

Expand Down Expand Up @@ -121,6 +123,7 @@ def test_empty_collection(self):
id="tmp",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
)
vector_store.create_index()
vector_store.load_documents([sample_doc])
vector_store.db_connection.open_table(
vector_store.index_name if vector_store.index_name else ""
Expand All @@ -137,7 +140,8 @@ def test_empty_collection(self):
id="1",
vector=[0.1, 0.2, 0.3, 0.4, 0.5],
)
vector_store.load_documents([doc], overwrite=False)
vector_store.create_index()
vector_store.load_documents([doc])

result = vector_store.search_by_id("1")
assert result.id == "1"
Expand All @@ -157,7 +161,7 @@ def test_filter_search(self, sample_documents_categories):
)

vector_store.connect(db_uri=temp_dir)

vector_store.create_index()
vector_store.load_documents(sample_documents_categories)

# Filter to include only documents about animals
Expand Down Expand Up @@ -186,6 +190,7 @@ def test_vector_store_customization(self, sample_documents):
),
)
vector_store.connect(db_uri=temp_dir)
vector_store.create_index()
vector_store.load_documents(sample_documents[:2])

if vector_store.index_name:
Expand All @@ -205,7 +210,8 @@ def test_vector_store_customization(self, sample_documents):
assert isinstance(results[0].score, float)

# Test append mode
vector_store.load_documents([sample_documents[2]], overwrite=False)
vector_store.create_index()
vector_store.load_documents([sample_documents[2]])
result = vector_store.search_by_id("3")
assert result.id == "3"

Expand Down
1 change: 0 additions & 1 deletion tests/unit/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def assert_vector_store_configs(
assert actual.api_key == expected.api_key
assert actual.audience == expected.audience
assert actual.container_name == expected.container_name
assert actual.overwrite == expected.overwrite
assert actual.database_name == expected.database_name


Expand Down
Loading
Loading