diff --git a/breaking-changes.md b/breaking-changes.md index b07aa6fca8..93430cbfe4 100644 --- a/breaking-changes.md +++ b/breaking-changes.md @@ -34,7 +34,7 @@ This is a summary of changes: - Collapsed the `vector_store` dict into a single root-level object. This is because we no longer support multi-search, and this dict required a lot of downstream complexity for that single use case. - Removed the `outputs` block that was also only used for multi-search. - Most workflows had an undocumented `strategy` config dict that allowed fine tuning of internal settings. These fine tunings are never used and had associated complexity, so we removed it. -- Vector store configuration now allows custom schema per embedded field. This overrides the need for the `container_name` prefix, which caused confusion anyway. Now, the default container name will simply be the embedded field name - if you need something custom, add the `embeddings_schema` block and populate as needed. +- Vector store configuration now allows custom schema per embedded field. This overrides the need for the `container_name` prefix, which caused confusion anyway. Now, the default container name will simply be the embedded field name - if you need something custom, add the `index_schema` block and populate as needed. - We previously supported the ability to embed any text field in the data model. However, we only ever use text_unit_text, entity_description, and community_full_content, so all others have been removed. - Removed the `umap` and `embed_graph` blocks which were only used to add x/y fields to the entities. This fixed a long-standing dependency issue with graspologic. If you need x/y positions, see the [visualization guide](https://microsoft.github.io/graphrag/visualization_guide/) for using gephi. - Removed file filtering from input document loading. This was essentially unused. diff --git a/docs/config/yaml.md b/docs/config/yaml.md index 1885db1fe5..0f67390833 100644 --- a/docs/config/yaml.md +++ b/docs/config/yaml.md @@ -172,9 +172,8 @@ Where to put all vectors for the system. Configured for lancedb by default. This - `url` **str** (only for AI Search) - AI Search endpoint - `api_key` **str** (optional - only for AI Search) - The AI Search api key to use. - `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used. -- `index_prefix` **str** - (optional) A prefix for the indexes you will create for embeddings. This stores all indexes (tables) for a given dataset ingest. - `database_name` **str** - (cosmosdb only) Name of the database. -- `embeddings_schema` **dict[str, dict[str, str]]** (optional) - Enables customization for each of your embeddings. +- `index_schema` **dict[str, dict[str, str]]** (optional) - Enables customization for each of your embeddings. - ``: - `index_name` **str**: (optional) - Name for the specific embedding index table. - `id_field` **str**: (optional) - Field name to be used as id. Default=`id` @@ -193,8 +192,7 @@ For example: vector_store: type: lancedb db_uri: output/lancedb - index_prefix: "christmas-carol" - embeddings_schema: + index_schema: text_unit_text: index_name: "text-unit-embeddings" id_field: "id_custom" diff --git a/docs/examples_notebooks/custom_vector_store.ipynb b/docs/examples_notebooks/custom_vector_store.ipynb index 60be6c1fb5..2e79c66d86 100644 --- a/docs/examples_notebooks/custom_vector_store.ipynb +++ b/docs/examples_notebooks/custom_vector_store.ipynb @@ -28,7 +28,7 @@ "\n", "### What You'll Learn\n", "\n", - "1. Understanding the `BaseVectorStore` interface\n", + "1. Understanding the `VectorStore` interface\n", "2. Implementing a custom vector store class\n", "3. Registering your vector store with the `VectorStoreFactory`\n", "4. Testing and validating your implementation\n", @@ -50,35 +50,13 @@ "```" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Any\n", - "\n", - "import numpy as np\n", - "import yaml\n", - "from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n", - "from graphrag.data_model.types import TextEmbedder\n", - "\n", - "# GraphRAG vector store components\n", - "from graphrag.vector_stores.base import (\n", - " BaseVectorStore,\n", - " VectorStoreDocument,\n", - " VectorStoreSearchResult,\n", - ")\n", - "from graphrag.vector_stores.factory import VectorStoreFactory" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Step 2: Understand the BaseVectorStore Interface\n", + "## Step 2: Understand the VectorStore Interface\n", "\n", - "Before using a custom vector store, let's examine the `BaseVectorStore` interface to understand what methods need to be implemented." + "Before using a custom vector store, let's examine the `VectorStore` interface to understand what methods need to be implemented." ] }, { @@ -87,18 +65,31 @@ "metadata": {}, "outputs": [], "source": [ - "# Let's inspect the BaseVectorStore class to understand the required methods\n", "import inspect\n", "\n", - "print(\"BaseVectorStore Abstract Methods:\")\n", - "print(\"=\" * 40)\n", + "# Let's inspect the VectorStore class to understand the required methods\n", + "from typing import Any\n", + "\n", + "import numpy as np\n", + "import yaml\n", + "from graphrag_vectors import (\n", + " IndexSchema,\n", + " TextEmbedder,\n", + " VectorStore,\n", + " VectorStoreDocument,\n", + " VectorStoreFactory,\n", + " VectorStoreSearchResult,\n", + ")\n", + "\n", + "print(\"VectorStore Abstract Methods:\")\n", + "print(\"=\" * 80)\n", "\n", "abstract_methods = []\n", - "for name, method in inspect.getmembers(BaseVectorStore, predicate=inspect.isfunction):\n", + "for name, method in inspect.getmembers(VectorStore, predicate=inspect.isfunction):\n", " if getattr(method, \"__isabstractmethod__\", False):\n", - " signature = inspect.signature(method)\n", - " abstract_methods.append(f\"• {name}{signature}\")\n", - " print(f\"• {name}{signature}\")\n", + " abstract_methods.append(name)\n", + " print(f\"\\n{name}:\")\n", + " print(f\" {inspect.signature(method)}\")\n", "\n", "print(f\"\\nTotal abstract methods to implement: {len(abstract_methods)}\")" ] @@ -112,7 +103,7 @@ "Now let's implement a simple in-memory vector store as an example. This vector store will:\n", "\n", "- Store documents and vectors in memory using Python data structures\n", - "- Support all required BaseVectorStore methods\n", + "- Support all required VectorStore methods\n", "\n", "**Note**: This is a simplified example for demonstration. Production vector stores would typically use optimized libraries like FAISS, more sophisticated indexing, and persistent storage." ] @@ -123,7 +114,7 @@ "metadata": {}, "outputs": [], "source": [ - "class SimpleInMemoryVectorStore(BaseVectorStore):\n", + "class SimpleInMemoryVectorStore(VectorStore):\n", " \"\"\"A simple in-memory vector store implementation for demonstration purposes.\n", "\n", " This vector store stores documents and their embeddings in memory and provides\n", @@ -147,112 +138,90 @@ " self.vectors: dict[str, np.ndarray] = {}\n", " self.connected = False\n", "\n", - " print(f\"🚀 SimpleInMemoryVectorStore initialized for index: {self.index_name}\")\n", - "\n", " def connect(self, **kwargs: Any) -> None:\n", - " \"\"\"Connect to the vector storage (no-op for in-memory store).\"\"\"\n", + " \"\"\"Connect to the vector store (simulated for in-memory store).\"\"\"\n", + " print(\"Connecting to in-memory vector store...\")\n", " self.connected = True\n", - " print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n", - "\n", - " def create_index(self) -> None:\n", - " \"\"\"Create index in the vector store (no-op for in-memory store).\"\"\"\n", - " self.documents.clear()\n", - " self.vectors.clear()\n", - "\n", - " print(f\"✅ Index '{self.index_name}' is ready in in-memory vector store\")\n", - "\n", - " def load_documents(self, documents: list[VectorStoreDocument]) -> None:\n", + " print(\"Connected successfully!\")\n", + "\n", + " def create_index(self, **kwargs: Any) -> None:\n", + " \"\"\"Create an index (simulated for in-memory store).\n", + "\n", + " In a real vector database, this would create the necessary data structures\n", + " and indexes for efficient vector search.\n", + " \"\"\"\n", + " print(f\"Creating index: {self.index_name}\")\n", + " # For in-memory store, we just ensure our storage dictionaries are ready\n", + " if not isinstance(self.documents, dict):\n", + " self.documents = {}\n", + " if not isinstance(self.vectors, dict):\n", + " self.vectors = {}\n", + " print(\"Index created successfully!\")\n", + "\n", + " def load_documents(\n", + " self, documents: list[VectorStoreDocument], overwrite: bool = False\n", + " ) -> None:\n", " \"\"\"Load documents into the vector store.\"\"\"\n", - " if not self.connected:\n", - " msg = \"Vector store not connected. Call connect() first.\"\n", - " raise RuntimeError(msg)\n", + " if overwrite:\n", + " print(\"Clearing existing documents...\")\n", + " self.documents.clear()\n", + " self.vectors.clear()\n", "\n", - " loaded_count = 0\n", + " print(f\"Loading {len(documents)} documents...\")\n", " for doc in documents:\n", - " if doc.vector is not None:\n", - " doc_id = str(doc.id)\n", - " self.documents[doc_id] = doc\n", - " self.vectors[doc_id] = np.array(doc.vector, dtype=np.float32)\n", - " loaded_count += 1\n", + " self.documents[doc.id] = doc\n", + " if doc.vector:\n", + " self.vectors[doc.id] = np.array(doc.vector)\n", "\n", - " print(f\"📚 Loaded {loaded_count} documents into vector store\")\n", - "\n", - " def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:\n", - " \"\"\"Calculate cosine similarity between two vectors.\"\"\"\n", - " # Normalize vectors\n", - " norm1 = np.linalg.norm(vec1)\n", - " norm2 = np.linalg.norm(vec2)\n", - "\n", - " if norm1 == 0 or norm2 == 0:\n", - " return 0.0\n", - "\n", - " return float(np.dot(vec1, vec2) / (norm1 * norm2))\n", + " print(f\"Successfully loaded {len(documents)} documents!\")\n", "\n", " def similarity_search_by_vector(\n", " self, query_embedding: list[float], k: int = 10, **kwargs: Any\n", " ) -> list[VectorStoreSearchResult]:\n", - " \"\"\"Perform similarity search using a query vector.\"\"\"\n", - " if not self.connected:\n", - " msg = \"Vector store not connected. Call connect() first.\"\n", - " raise RuntimeError(msg)\n", - "\n", + " \"\"\"Search for similar documents using a query vector.\"\"\"\n", " if not self.vectors:\n", " return []\n", "\n", - " query_vec = np.array(query_embedding, dtype=np.float32)\n", - " similarities = []\n", + " query_vector = np.array(query_embedding)\n", "\n", - " # Calculate similarity with all stored vectors\n", - " for doc_id, stored_vec in self.vectors.items():\n", - " similarity = self._cosine_similarity(query_vec, stored_vec)\n", + " # Calculate cosine similarity for all documents\n", + " similarities = []\n", + " for doc_id, doc_vector in self.vectors.items():\n", + " # Cosine similarity\n", + " similarity = np.dot(query_vector, doc_vector) / (\n", + " np.linalg.norm(query_vector) * np.linalg.norm(doc_vector)\n", + " )\n", " similarities.append((doc_id, similarity))\n", "\n", - " # Sort by similarity (descending) and take top k\n", + " # Sort by similarity (highest first) and take top k\n", " similarities.sort(key=lambda x: x[1], reverse=True)\n", - " top_k = similarities[:k]\n", + " top_results = similarities[:k]\n", "\n", - " # Create search results\n", + " # Convert to search results\n", " results = []\n", - " for doc_id, score in top_k:\n", - " document = self.documents[doc_id]\n", - " result = VectorStoreSearchResult(document=document, score=score)\n", - " results.append(result)\n", + " for doc_id, score in top_results:\n", + " doc = self.documents[doc_id]\n", + " results.append(VectorStoreSearchResult(document=doc, score=float(score)))\n", "\n", " return results\n", "\n", " def similarity_search_by_text(\n", - " self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any\n", + " self,\n", + " text: str,\n", + " text_embedder: TextEmbedder,\n", + " k: int = 10,\n", + " **kwargs: Any,\n", " ) -> list[VectorStoreSearchResult]:\n", - " \"\"\"Perform similarity search using text (which gets embedded first).\"\"\"\n", - " # Embed the text first\n", + " \"\"\"Search for similar documents using a text query.\"\"\"\n", + " # Embed the query text\n", " query_embedding = text_embedder(text)\n", "\n", - " # Use vector search with the embedding\n", + " # Use vector search\n", " return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n", "\n", - " def search_by_id(self, id: str) -> VectorStoreDocument:\n", - " \"\"\"Search for a document by id.\"\"\"\n", - " doc_id = str(id)\n", - " if doc_id not in self.documents:\n", - " msg = f\"Document with id '{id}' not found\"\n", - " raise KeyError(msg)\n", - "\n", - " return self.documents[doc_id]\n", - "\n", - " def get_stats(self) -> dict[str, Any]:\n", - " \"\"\"Get statistics about the vector store (custom method).\"\"\"\n", - " return {\n", - " \"index_name\": self.index_name,\n", - " \"document_count\": len(self.documents),\n", - " \"vector_count\": len(self.vectors),\n", - " \"connected\": self.connected,\n", - " \"vector_dimension\": len(next(iter(self.vectors.values())))\n", - " if self.vectors\n", - " else 0,\n", - " }\n", - "\n", - "\n", - "print(\"✅ SimpleInMemoryVectorStore class defined!\")" + " def search_by_id(self, id: str) -> VectorStoreDocument | None:\n", + " \"\"\"Retrieve a document by its ID.\"\"\"\n", + " return self.documents.get(id)" ] }, { @@ -337,17 +306,24 @@ "outputs": [], "source": [ "# Test creating vector store using the factory\n", - "schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n", + "schema = IndexSchema(index_name=\"test_collection\")\n", "\n", "# Create vector store instance using factory\n", "vector_store = VectorStoreFactory().create(\n", - " CUSTOM_VECTOR_STORE_TYPE, {\"vector_store_schema_config\": schema}\n", + " CUSTOM_VECTOR_STORE_TYPE, {\"index_schema\": schema}\n", ")\n", "\n", "print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n", "print(f\"📊 Initial stats: {vector_store.get_stats()}\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -464,11 +440,11 @@ " print(\"🚀 Simulating GraphRAG pipeline with custom vector store...\\n\")\n", "\n", " # 1. GraphRAG creates vector store using factory\n", - " schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n", + " schema = IndexSchema(index_name=\"graphrag_entities\")\n", "\n", " store = VectorStoreFactory().create(\n", " CUSTOM_VECTOR_STORE_TYPE,\n", - " {\"vector_store_schema_config\": schema, \"similarity_threshold\": 0.3},\n", + " {\"index_schema\": schema, \"similarity_threshold\": 0.3},\n", " )\n", " store.connect()\n", " store.create_index()\n", @@ -529,7 +505,7 @@ " print(\"Test 1: Basic functionality\")\n", " store = VectorStoreFactory().create(\n", " CUSTOM_VECTOR_STORE_TYPE,\n", - " {\"vector_store_schema_config\": VectorStoreSchemaConfig(index_name=\"test\")},\n", + " {\"index_schema\": IndexSchema(index_name=\"test\")},\n", " )\n", " store.connect()\n", " store.create_index()\n", @@ -572,7 +548,7 @@ " print(\"\\nTest 5: Error handling\")\n", " disconnected_store = VectorStoreFactory().create(\n", " CUSTOM_VECTOR_STORE_TYPE,\n", - " {\"vector_store_schema_config\": VectorStoreSchemaConfig(index_name=\"test2\")},\n", + " {\"index_schema\": IndexSchema(index_name=\"test2\")},\n", " )\n", "\n", " try:\n", @@ -611,16 +587,24 @@ "- ✅ **Configuration Examples**: Learned how to configure GraphRAG to use your vector store\n", "\n", "### Key Takeaways\n", - "1. **Interface Compliance**: Always implement all methods from `BaseVectorStore`\n", + "1. **Interface Compliance**: Always implement all methods from `VectorStore`\n", "2. **Factory Pattern**: Use `VectorStoreFactory.register()` to make your vector store available\n", - "3. **Configuration**: Vector stores are configured in GraphRAG settings files\n", - "4. **Testing**: Thoroughly test all functionality before deploying\n", - "\n", - "### Next Steps\n", - "Check out the API Overview notebook to learn how to index and query data via the graphrag API.\n", + "3. **Testing**: Validate your implementation thoroughly before production use\n", + "4. **Configuration**: Use YAML or environment variables for flexible configuration\n", + "\n", + "### Production Considerations\n", + "For production use, consider:\n", + "- **Persistence**: Add data persistence mechanisms\n", + "- **Scalability**: Use optimized vector search libraries (FAISS, HNSW)\n", + "- **Error Handling**: Implement robust error handling and logging\n", + "- **Performance**: Add caching, batching, and connection pooling\n", + "- **Security**: Implement authentication and authorization\n", + "- **Monitoring**: Add metrics and health checks\n", "\n", "### Resources\n", "- [GraphRAG Documentation](https://microsoft.github.io/graphrag/)\n", + "- [Vector Store Examples](https://github.com/microsoft/graphrag/tree/main/packages/graphrag-vectors)\n", + "- [GraphRAG GitHub Repository](https://github.com/microsoft/graphrag)\n", "\n", "Happy building! 🚀" ] @@ -628,7 +612,7 @@ ], "metadata": { "kernelspec": { - "display_name": "graphrag", + "display_name": "Python 3", "language": "python", "name": "python3" }, diff --git a/docs/examples_notebooks/drift_search.ipynb b/docs/examples_notebooks/drift_search.ipynb index 81cd193eec..8d53c7d9cc 100644 --- a/docs/examples_notebooks/drift_search.ipynb +++ b/docs/examples_notebooks/drift_search.ipynb @@ -22,7 +22,6 @@ "from graphrag.config.enums import ModelType\n", "from graphrag.config.models.drift_search_config import DRIFTSearchConfig\n", "from graphrag.config.models.language_model_config import LanguageModelConfig\n", - "from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n", "from graphrag.language_model.manager import ModelManager\n", "from graphrag.query.indexer_adapters import (\n", " read_indexer_entities,\n", @@ -36,7 +35,7 @@ ")\n", "from graphrag.query.structured_search.drift_search.search import DRIFTSearch\n", "from graphrag.tokenizer.get_tokenizer import get_tokenizer\n", - "from graphrag.vector_stores.lancedb import LanceDBVectorStore\n", + "from graphrag_vectors.lancedb import LanceDBVectorStore\n", "\n", "INPUT_DIR = \"./inputs/operation dulce\"\n", "LANCEDB_URI = f\"{INPUT_DIR}/lancedb\"\n", @@ -61,16 +60,16 @@ "# load description embeddings to an in-memory lancedb vectorstore\n", "# to connect to a remote db, specify url and port values.\n", "description_embedding_store = LanceDBVectorStore(\n", - " vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"entity_description\"),\n", + " db_uri=LANCEDB_URI,\n", + " index_name=\"entity_description\",\n", ")\n", - "description_embedding_store.connect(db_uri=LANCEDB_URI)\n", + "description_embedding_store.connect()\n", "\n", "full_content_embedding_store = LanceDBVectorStore(\n", - " vector_store_schema_config=VectorStoreSchemaConfig(\n", - " index_name=\"community_full_content\"\n", - " )\n", + " db_uri=LANCEDB_URI,\n", + " index_name=\"community_full_content\",\n", ")\n", - "full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n", + "full_content_embedding_store.connect()\n", "\n", "print(f\"Entity count: {len(entity_df)}\")\n", "entity_df.head()\n", diff --git a/docs/examples_notebooks/local_search.ipynb b/docs/examples_notebooks/local_search.ipynb index 4a5497a824..f7f0c5a54b 100644 --- a/docs/examples_notebooks/local_search.ipynb +++ b/docs/examples_notebooks/local_search.ipynb @@ -19,7 +19,6 @@ "import os\n", "\n", "import pandas as pd\n", - "from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig\n", "from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey\n", "from graphrag.query.indexer_adapters import (\n", " read_indexer_covariates,\n", @@ -33,7 +32,7 @@ " LocalSearchMixedContext,\n", ")\n", "from graphrag.query.structured_search.local_search.search import LocalSearch\n", - "from graphrag.vector_stores.lancedb import LanceDBVectorStore" + "from graphrag_vectors import IndexSchema, LanceDBVectorStore" ] }, { @@ -101,9 +100,7 @@ "# load description embeddings to an in-memory lancedb vectorstore\n", "# to connect to a remote db, specify url and port values.\n", "description_embedding_store = LanceDBVectorStore(\n", - " vector_store_schema_config=VectorStoreSchemaConfig(\n", - " index_name=\"default-entity-description\"\n", - " )\n", + " index_schema=IndexSchema(index_name=\"default-entity-description\")\n", ")\n", "description_embedding_store.connect(db_uri=LANCEDB_URI)\n", "\n", diff --git a/packages/graphrag-vectors/README.md b/packages/graphrag-vectors/README.md new file mode 100644 index 0000000000..1966c442e8 --- /dev/null +++ b/packages/graphrag-vectors/README.md @@ -0,0 +1,109 @@ +# GraphRAG Vectors + +Vector store implementations for GraphRAG. + +## Basic Usage + +### Using the utility function (recommended) + +```python +from graphrag_vectors import ( + create_vector_store, + VectorStoreType, + IndexSchema, +) + +# Create a vector store using the convenience function +store_config = VectorStoreConfig( + type="lancedb", + db_uri="lance" +) + +schema_config = IndexSchema( + index_name="my_index", + vector_size=1536, +) + +vector_store = create_vector_store( + config=store_config + index_schema=schema_config, +) + +vector_store.connect() +vector_store.create_index() +``` + +### Using the factory directly + +```python +from graphrag_vectors import ( + VectorStoreFactory, + vector_store_factory, + VectorStoreType, + IndexSchema, +) + +# Create a vector store using the factory +schema_config = IndexSchema( + index_name="my_index", + vector_size=1536, +) + +vector_store = vector_store_factory.create( + VectorStoreType.LanceDB, + { + "index_schema": schema_config, + "db_uri": "./lancedb" + } +) + +vector_store.connect() +vector_store.create_index() +``` + +## Supported Vector Stores + +- **LanceDB**: Local vector database +- **Azure AI Search**: Azure's managed search service with vector capabilities +- **Azure Cosmos DB**: Azure's NoSQL database with vector search support + +## Custom Vector Store + +You can register custom vector store implementations: + +```python +from graphrag_vectors import VectorStore, register_vector_store, create_vector_store + +class MyCustomVectorStore(VectorStore): + def __init__(self, my_param): + self.my_param = my_param + + def connect(self): + # Implementation + pass + + def create_index(self): + # Implementation + pass + + # ... implement other required methods + +# Register your custom implementation +register_vector_store("my_custom_store", MyCustomVectorStore) + +# Use your custom vector store +config = VectorStoreConfig( + type="my_custom_store", + my_param="something" +) +custom_store = create_vector_store( + config=config, + index_schema=schema_config, +) +``` + +## Configuration + +Vector stores are configured using: +- `VectorStoreConfig`: baseline parameters for the store +- `IndexSchema`: Schema configuration for the specific index to create/connect to (index name, field names, vector size) diff --git a/packages/graphrag-vectors/graphrag_vectors/__init__.py b/packages/graphrag-vectors/graphrag_vectors/__init__.py new file mode 100644 index 0000000000..915d1f0cd1 --- /dev/null +++ b/packages/graphrag-vectors/graphrag_vectors/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""GraphRAG vector store implementations.""" + +from graphrag_vectors.index_schema import IndexSchema +from graphrag_vectors.types import TextEmbedder +from graphrag_vectors.vector_store import ( + VectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) +from graphrag_vectors.vector_store_config import VectorStoreConfig +from graphrag_vectors.vector_store_factory import ( + VectorStoreFactory, + create_vector_store, + register_vector_store, + vector_store_factory, +) +from graphrag_vectors.vector_store_type import VectorStoreType + +__all__ = [ + "IndexSchema", + "TextEmbedder", + "VectorStore", + "VectorStoreConfig", + "VectorStoreDocument", + "VectorStoreFactory", + "VectorStoreSearchResult", + "VectorStoreType", + "create_vector_store", + "register_vector_store", + "vector_store_factory", +] diff --git a/packages/graphrag/graphrag/vector_stores/azure_ai_search.py b/packages/graphrag-vectors/graphrag_vectors/azure_ai_search.py similarity index 71% rename from packages/graphrag/graphrag/vector_stores/azure_ai_search.py rename to packages/graphrag-vectors/graphrag_vectors/azure_ai_search.py index e6193e09d9..fcbb5e48ae 100644 --- a/packages/graphrag/graphrag/vector_stores/azure_ai_search.py +++ b/packages/graphrag-vectors/graphrag_vectors/azure_ai_search.py @@ -22,49 +22,59 @@ ) from azure.search.documents.models import VectorizedQuery -from graphrag.data_model.types import TextEmbedder -from graphrag.vector_stores.base import ( - BaseVectorStore, +from graphrag_vectors.vector_store import ( + VectorStore, VectorStoreDocument, VectorStoreSearchResult, ) -class AzureAISearchVectorStore(BaseVectorStore): +class AzureAISearchVectorStore(VectorStore): """Azure AI Search vector storage implementation.""" index_client: SearchIndexClient - def connect(self, **kwargs: Any) -> Any: + def __init__( + self, + url: str, + api_key: str | None = None, + audience: str | None = None, + vector_search_profile_name: str = "vectorSearchProfile", + **kwargs: Any, + ): + super().__init__(**kwargs) + if not url: + msg = "url must be provided for Azure AI Search." + raise ValueError(msg) + self.url = url + self.api_key = api_key + self.audience = audience + self.vector_search_profile_name = vector_search_profile_name + + def connect(self) -> Any: """Connect to AI search vector storage.""" - url = kwargs["url"] - api_key = kwargs.get("api_key") - audience = kwargs.get("audience") - - self.vector_search_profile_name = kwargs.get( - "vector_search_profile_name", "vectorSearchProfile" + audience_arg = ( + {"audience": self.audience} if self.audience and not self.api_key else {} + ) + self.db_connection = SearchClient( + endpoint=self.url, + index_name=self.index_name, + credential=( + AzureKeyCredential(self.api_key) + if self.api_key + else DefaultAzureCredential() + ), + **audience_arg, + ) + self.index_client = SearchIndexClient( + endpoint=self.url, + credential=( + AzureKeyCredential(self.api_key) + if self.api_key + else DefaultAzureCredential() + ), + **audience_arg, ) - - if url: - audience_arg = {"audience": audience} if audience and not api_key else {} - self.db_connection = SearchClient( - endpoint=url, - index_name=self.index_name if self.index_name else "", - credential=( - AzureKeyCredential(api_key) if api_key else DefaultAzureCredential() - ), - **audience_arg, - ) - self.index_client = SearchIndexClient( - endpoint=url, - credential=( - AzureKeyCredential(api_key) if api_key else DefaultAzureCredential() - ), - **audience_arg, - ) - else: - not_supported_error = "Azure AI Search expects `url`." - raise ValueError(not_supported_error) def create_index(self) -> None: """Load documents into an Azure AI Search index.""" @@ -93,7 +103,7 @@ def create_index(self) -> None: ) # Configure the index index = SearchIndex( - name=self.index_name if self.index_name else "", + name=self.index_name, fields=[ SimpleField( name=self.id_field, @@ -154,17 +164,6 @@ def similarity_search_by_vector( for doc in response ] - def similarity_search_by_text( - self, text: str, text_embedder: TextEmbedder, k: int = 10 - ) -> list[VectorStoreSearchResult]: - """Perform a text-based similarity search.""" - query_embedding = text_embedder(text) - if query_embedding: - return self.similarity_search_by_vector( - query_embedding=query_embedding, k=k - ) - return [] - def search_by_id(self, id: str) -> VectorStoreDocument: """Search for a document by id.""" response = self.db_connection.get_document(id) diff --git a/packages/graphrag/graphrag/vector_stores/cosmosdb.py b/packages/graphrag-vectors/graphrag_vectors/cosmosdb.py similarity index 81% rename from packages/graphrag/graphrag/vector_stores/cosmosdb.py rename to packages/graphrag-vectors/graphrag_vectors/cosmosdb.py index 7ad06950c2..fcff47e782 100644 --- a/packages/graphrag/graphrag/vector_stores/cosmosdb.py +++ b/packages/graphrag-vectors/graphrag_vectors/cosmosdb.py @@ -10,67 +10,71 @@ from azure.cosmos.partition_key import PartitionKey from azure.identity import DefaultAzureCredential -from graphrag.data_model.types import TextEmbedder -from graphrag.vector_stores.base import ( - BaseVectorStore, +from graphrag_vectors.vector_store import ( + VectorStore, VectorStoreDocument, VectorStoreSearchResult, ) -class CosmosDBVectorStore(BaseVectorStore): +class CosmosDBVectorStore(VectorStore): """Azure CosmosDB vector storage implementation.""" _cosmos_client: CosmosClient _database_client: DatabaseProxy _container_client: ContainerProxy - def connect(self, **kwargs: Any) -> Any: + def __init__( + self, + database_name: str, + connection_string: str | None = None, + url: str | None = None, + **kwargs, + ): + super().__init__(**kwargs) + if self.id_field != "id": + msg = "CosmosDB requires the id_field to be 'id'." + raise ValueError(msg) + if not connection_string and not url: + msg = "Either connection_string or url must be provided for CosmosDB." + raise ValueError(msg) + + self.database_name = database_name + self.connection_string = connection_string + self.url = url + + def connect(self) -> Any: """Connect to CosmosDB vector storage.""" - connection_string = kwargs.get("connection_string") - if connection_string: - self._cosmos_client = CosmosClient.from_connection_string(connection_string) + if self.connection_string: + self._cosmos_client = CosmosClient.from_connection_string( + self.connection_string + ) else: - url = kwargs.get("url") - if not url: - msg = "Either connection_string or url must be provided." - raise ValueError(msg) self._cosmos_client = CosmosClient( - url=url, credential=DefaultAzureCredential() + url=self.url, credential=DefaultAzureCredential() ) - database_name = kwargs.get("database_name") - if database_name is None: - msg = "Database name must be provided." - raise ValueError(msg) - self._database_name = database_name - if self.index_name is None: - msg = "Index name is empty or not provided." - raise ValueError(msg) - self._container_name = self.index_name - - self.vector_size = self.vector_size self._create_database() self._create_container() def _create_database(self) -> None: """Create the database if it doesn't exist.""" - self._cosmos_client.create_database_if_not_exists(id=self._database_name) + self._cosmos_client.create_database_if_not_exists(id=self.database_name) self._database_client = self._cosmos_client.get_database_client( - self._database_name + self.database_name ) def _delete_database(self) -> None: """Delete the database if it exists.""" if self._database_exists(): - self._cosmos_client.delete_database(self._database_name) + self._cosmos_client.delete_database(self.database_name) def _database_exists(self) -> bool: """Check if the database exists.""" existing_database_names = [ database["id"] for database in self._cosmos_client.list_databases() ] - return self._database_name in existing_database_names + return self.database_name in existing_database_names def _create_container(self) -> None: """Create the container if it doesn't exist.""" @@ -108,7 +112,7 @@ def _create_container(self) -> None: # Create the container and container client self._database_client.create_container_if_not_exists( - id=self._container_name, + id=self.index_name, partition_key=partition_key, indexing_policy=indexing_policy, vector_embedding_policy=vector_embedding_policy, @@ -119,27 +123,27 @@ def _create_container(self) -> None: # Create the container with compatible indexing policy self._database_client.create_container_if_not_exists( - id=self._container_name, + id=self.index_name, partition_key=partition_key, indexing_policy=indexing_policy, vector_embedding_policy=vector_embedding_policy, ) self._container_client = self._database_client.get_container_client( - self._container_name + self.index_name ) def _delete_container(self) -> None: """Delete the vector store container in the database if it exists.""" if self._container_exists(): - self._database_client.delete_container(self._container_name) + self._database_client.delete_container(self.index_name) def _container_exists(self) -> bool: """Check if the container name exists in the database.""" existing_container_names = [ container["id"] for container in self._database_client.list_containers() ] - return self._container_name in existing_container_names + return self.index_name in existing_container_names def create_index(self) -> None: """Load documents into CosmosDB.""" @@ -222,17 +226,6 @@ def cosine_similarity(a, b): for item in items ] - def similarity_search_by_text( - self, text: str, text_embedder: TextEmbedder, k: int = 10 - ) -> list[VectorStoreSearchResult]: - """Perform a text-based similarity search.""" - query_embedding = text_embedder(text) - if query_embedding: - return self.similarity_search_by_vector( - query_embedding=query_embedding, k=k - ) - return [] - def search_by_id(self, id: str) -> VectorStoreDocument: """Search for a document by id.""" if self._container_client is None: diff --git a/packages/graphrag/graphrag/config/models/vector_store_schema_config.py b/packages/graphrag-vectors/graphrag_vectors/index_schema.py similarity index 90% rename from packages/graphrag/graphrag/config/models/vector_store_schema_config.py rename to packages/graphrag-vectors/graphrag_vectors/index_schema.py index ccb91b3cbb..2eacc3f0c0 100644 --- a/packages/graphrag/graphrag/config/models/vector_store_schema_config.py +++ b/packages/graphrag-vectors/graphrag_vectors/index_schema.py @@ -17,9 +17,13 @@ def is_valid_field_name(field: str) -> bool: return bool(VALID_IDENTIFIER_REGEX.match(field)) -class VectorStoreSchemaConfig(BaseModel): +class IndexSchema(BaseModel): """The default configuration section for Vector Store Schema.""" + index_name: str = Field( + description="The index name to use.", default="vector_index" + ) + id_field: str = Field( description="The ID field to use.", default="id", @@ -35,8 +39,6 @@ class VectorStoreSchemaConfig(BaseModel): default=DEFAULT_VECTOR_SIZE, ) - index_name: str | None = Field(description="The index name to use.", default=None) - def _validate_schema(self) -> None: """Validate the schema.""" for field in [ diff --git a/packages/graphrag/graphrag/vector_stores/lancedb.py b/packages/graphrag-vectors/graphrag_vectors/lancedb.py similarity index 74% rename from packages/graphrag/graphrag/vector_stores/lancedb.py rename to packages/graphrag-vectors/graphrag_vectors/lancedb.py index bddc83cab2..724885d34b 100644 --- a/packages/graphrag/graphrag/vector_stores/lancedb.py +++ b/packages/graphrag-vectors/graphrag_vectors/lancedb.py @@ -9,20 +9,23 @@ import numpy as np import pyarrow as pa -from graphrag.data_model.types import TextEmbedder -from graphrag.vector_stores.base import ( - BaseVectorStore, +from graphrag_vectors.vector_store import ( + VectorStore, VectorStoreDocument, VectorStoreSearchResult, ) -class LanceDBVectorStore(BaseVectorStore): +class LanceDBVectorStore(VectorStore): """LanceDB vector storage implementation.""" - def connect(self, **kwargs: Any) -> Any: + def __init__(self, db_uri: str = "lancedb", **kwargs: Any): + super().__init__(**kwargs) + self.db_uri = db_uri + + def connect(self) -> Any: """Connect to the vector storage.""" - self.db_connection = lancedb.connect(kwargs["db_uri"]) + self.db_connection = lancedb.connect(self.db_uri) if self.index_name and self.index_name in self.db_connection.table_names(): self.document_collection = self.db_connection.open_table(self.index_name) @@ -90,25 +93,15 @@ def similarity_search_by_vector( self, query_embedding: list[float] | np.ndarray, k: int = 10 ) -> list[VectorStoreSearchResult]: """Perform a vector-based similarity search.""" - if self.query_filter: - docs = ( - self.document_collection.search( - query=query_embedding, vector_column_name=self.vector_field - ) - .where(self.query_filter, prefilter=True) - .limit(k) - .to_list() - ) - else: - query_embedding = np.array(query_embedding, dtype=np.float32) - - docs = ( - self.document_collection.search( - query=query_embedding, vector_column_name=self.vector_field - ) - .limit(k) - .to_list() + query_embedding = np.array(query_embedding, dtype=np.float32) + + docs = ( + self.document_collection.search( + query=query_embedding, vector_column_name=self.vector_field ) + .limit(k) + .to_list() + ) return [ VectorStoreSearchResult( document=VectorStoreDocument( @@ -120,15 +113,6 @@ def similarity_search_by_vector( for doc in docs ] - def similarity_search_by_text( - self, text: str, text_embedder: TextEmbedder, k: int = 10 - ) -> list[VectorStoreSearchResult]: - """Perform a similarity search using a given input text.""" - query_embedding = text_embedder(text) - if query_embedding: - return self.similarity_search_by_vector(query_embedding, k) - return [] - def search_by_id(self, id: str) -> VectorStoreDocument: """Search for a document by id.""" doc = ( diff --git a/packages/graphrag-vectors/graphrag_vectors/types.py b/packages/graphrag-vectors/graphrag_vectors/types.py new file mode 100644 index 0000000000..63b032486c --- /dev/null +++ b/packages/graphrag-vectors/graphrag_vectors/types.py @@ -0,0 +1,8 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Common types for vector stores.""" + +from collections.abc import Callable + +TextEmbedder = Callable[[str], list[float]] diff --git a/packages/graphrag/graphrag/vector_stores/base.py b/packages/graphrag-vectors/graphrag_vectors/vector_store.py similarity index 63% rename from packages/graphrag/graphrag/vector_stores/base.py rename to packages/graphrag-vectors/graphrag_vectors/vector_store.py index eda7e189b6..de19bb5a4c 100644 --- a/packages/graphrag/graphrag/vector_stores/base.py +++ b/packages/graphrag-vectors/graphrag_vectors/vector_store.py @@ -7,8 +7,7 @@ from dataclasses import dataclass from typing import Any -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.data_model.types import TextEmbedder +from graphrag_vectors.types import TextEmbedder @dataclass @@ -32,29 +31,24 @@ class VectorStoreSearchResult: """Similarity score between -1 and 1. Higher is more similar.""" -class BaseVectorStore(ABC): +class VectorStore(ABC): """The base class for vector storage data-access classes.""" def __init__( self, - vector_store_schema_config: VectorStoreSchemaConfig, - db_connection: Any | None = None, - document_collection: Any | None = None, - query_filter: Any | None = None, + index_name: str = "vector_index", + id_field: str = "id", + vector_field: str = "vector", + vector_size: int = 3072, **kwargs: Any, ): - self.db_connection = db_connection - self.document_collection = document_collection - self.query_filter = query_filter - self.kwargs = kwargs - - self.index_name = vector_store_schema_config.index_name - self.id_field = vector_store_schema_config.id_field - self.vector_field = vector_store_schema_config.vector_field - self.vector_size = vector_store_schema_config.vector_size + self.index_name = index_name + self.id_field = id_field + self.vector_field = vector_field + self.vector_size = vector_size @abstractmethod - def connect(self, **kwargs: Any) -> None: + def connect(self) -> None: """Connect to vector storage.""" @abstractmethod @@ -71,11 +65,16 @@ def similarity_search_by_vector( ) -> list[VectorStoreSearchResult]: """Perform ANN search by vector.""" - @abstractmethod def similarity_search_by_text( self, text: str, text_embedder: TextEmbedder, k: int = 10 ) -> list[VectorStoreSearchResult]: - """Perform ANN search by text.""" + """Perform a text-based similarity search.""" + query_embedding = text_embedder(text) + if query_embedding: + return self.similarity_search_by_vector( + query_embedding=query_embedding, k=k + ) + return [] @abstractmethod def search_by_id(self, id: str) -> VectorStoreDocument: diff --git a/packages/graphrag-vectors/graphrag_vectors/vector_store_config.py b/packages/graphrag-vectors/graphrag_vectors/vector_store_config.py new file mode 100644 index 0000000000..17f70af86d --- /dev/null +++ b/packages/graphrag-vectors/graphrag_vectors/vector_store_config.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pydantic import BaseModel, ConfigDict, Field + +from graphrag_vectors.index_schema import IndexSchema +from graphrag_vectors.vector_store_type import VectorStoreType + + +class VectorStoreConfig(BaseModel): + """The default configuration section for Vector Store.""" + + model_config = ConfigDict(extra="allow") + """Allow extra fields to support custom vector implementations.""" + + type: str = Field( + description="The vector store type to use.", + default=VectorStoreType.LanceDB, + ) + + db_uri: str | None = Field( + description="The database URI to use (only used by lancedb for built-in stores).", + default=None, + ) + + url: str | None = Field( + description="The database URL when type == azure_ai_search or cosmosdb.", + default=None, + ) + + api_key: str | None = Field( + description="The database API key when type == azure_ai_search.", + default=None, + ) + + audience: str | None = Field( + description="The database audience when type == azure_ai_search.", + default=None, + ) + + connection_string: str | None = Field( + description="The connection string when type == cosmosdb.", + default=None, + ) + + database_name: str | None = Field( + description="The database name to use when type == cosmosdb.", + default=None, + ) + + index_schema: dict[str, IndexSchema] = {} diff --git a/packages/graphrag-vectors/graphrag_vectors/vector_store_factory.py b/packages/graphrag-vectors/graphrag_vectors/vector_store_factory.py new file mode 100644 index 0000000000..6d94fa63a9 --- /dev/null +++ b/packages/graphrag-vectors/graphrag_vectors/vector_store_factory.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Factory functions for creating a vector store.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graphrag_common.factory import Factory, ServiceScope + +from graphrag_vectors.vector_store import VectorStore +from graphrag_vectors.vector_store_type import VectorStoreType + +if TYPE_CHECKING: + from collections.abc import Callable + + from graphrag_vectors.index_schema import IndexSchema + from graphrag_vectors.vector_store_config import VectorStoreConfig + + +class VectorStoreFactory(Factory[VectorStore]): + """A factory for vector stores. + + Includes a method for users to register a custom vector store implementation. + + Configuration arguments are passed to each vector store implementation as kwargs + for individual enforcement of required/optional arguments. + """ + + +vector_store_factory = VectorStoreFactory() + + +def register_vector_store( + vector_store_type: str, + vector_store_initializer: Callable[..., VectorStore], + scope: ServiceScope = "transient", +) -> None: + """Register a custom vector store implementation. + + Args + ---- + - vector_store_type: str + The vector store id to register. + - vector_store_initializer: Callable[..., VectorStore] + The vector store initializer to register. + - scope: ServiceScope + The service scope for the vector store (default: "transient"). + """ + vector_store_factory.register(vector_store_type, vector_store_initializer, scope) + + +def create_vector_store( + config: VectorStoreConfig, index_schema: IndexSchema +) -> VectorStore: + """Create a vector store implementation based on the given type and configuration. + + Args + ---- + - config: VectorStoreConfig + The base vector store configuration. + - index_schema: IndexSchema + The index schema configuration for the vector store instance - i.e., for the specific table we are reading/writing. + + Returns + ------- + VectorStore + The created vector store implementation. + """ + strategy = config.type + + # Lazy load built-in implementations + if strategy not in vector_store_factory: + match strategy: + case VectorStoreType.LanceDB: + from graphrag_vectors.lancedb import LanceDBVectorStore + + register_vector_store(VectorStoreType.LanceDB, LanceDBVectorStore) + case VectorStoreType.AzureAISearch: + from graphrag_vectors.azure_ai_search import AzureAISearchVectorStore + + register_vector_store( + VectorStoreType.AzureAISearch, AzureAISearchVectorStore + ) + case VectorStoreType.CosmosDB: + from graphrag_vectors.cosmosdb import CosmosDBVectorStore + + register_vector_store(VectorStoreType.CosmosDB, CosmosDBVectorStore) + case _: + msg = f"Vector store type '{strategy}' is not registered in the VectorStoreFactory. Registered types: {', '.join(vector_store_factory.keys())}." + raise ValueError(msg) + + # collapse the base config and specific index config into a single dict for the initializer + config_model = config.model_dump() + index_model = index_schema.model_dump() + return vector_store_factory.create( + strategy, init_args={**config_model, **index_model} + ) diff --git a/packages/graphrag-vectors/graphrag_vectors/vector_store_type.py b/packages/graphrag-vectors/graphrag_vectors/vector_store_type.py new file mode 100644 index 0000000000..86b60bf2fc --- /dev/null +++ b/packages/graphrag-vectors/graphrag_vectors/vector_store_type.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Vector store type enum.""" + +from enum import StrEnum + + +class VectorStoreType(StrEnum): + """The supported vector store types.""" + + LanceDB = "lancedb" + AzureAISearch = "azure_ai_search" + CosmosDB = "cosmosdb" diff --git a/packages/graphrag-vectors/pyproject.toml b/packages/graphrag-vectors/pyproject.toml new file mode 100644 index 0000000000..e851f244d2 --- /dev/null +++ b/packages/graphrag-vectors/pyproject.toml @@ -0,0 +1,49 @@ +[project] +name = "graphrag-vectors" +version = "2.7.0" +description = "GraphRAG vector store package." +authors = [ + {name = "Alonso Guevara Fernández", email = "alonsog@microsoft.com"}, + {name = "Andrés Morales Esquivel", email = "andresmor@microsoft.com"}, + {name = "Chris Trevino", email = "chtrevin@microsoft.com"}, + {name = "David Tittsworth", email = "datittsw@microsoft.com"}, + {name = "Dayenne de Souza", email = "ddesouza@microsoft.com"}, + {name = "Derek Worthen", email = "deworthe@microsoft.com"}, + {name = "Gaudy Blanco Meneses", email = "gaudyb@microsoft.com"}, + {name = "Ha Trinh", email = "trinhha@microsoft.com"}, + {name = "Jonathan Larson", email = "jolarso@microsoft.com"}, + {name = "Josh Bradley", email = "joshbradley@microsoft.com"}, + {name = "Kate Lytvynets", email = "kalytv@microsoft.com"}, + {name = "Kenny Zhang", email = "zhangken@microsoft.com"}, + {name = "Mónica Carvajal"}, + {name = "Nathan Evans", email = "naevans@microsoft.com"}, + {name = "Rodrigo Racanicci", email = "rracanicci@microsoft.com"}, + {name = "Sarah Smith", email = "smithsarah@microsoft.com"}, +] +license = {text = "MIT"} +readme = "README.md" +requires-python = ">=3.11,<3.14" +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "azure-core~=1.32", + "azure-cosmos~=4.9", + "azure-identity~=1.19", + "azure-search-documents~=11.6", + "graphrag-common==2.7.0", + "lancedb~=0.24.1", + "numpy~=2.2", + "pyarrow~=22.0", + "pydantic~=2.10", +] + +[project.urls] +Source = "https://github.com/microsoft/graphrag" + +[build-system] +requires = ["hatchling>=1.27.0,<2.0.0"] +build-backend = "hatchling.build" diff --git a/packages/graphrag/graphrag/api/query.py b/packages/graphrag/graphrag/api/query.py index f573559bcb..fe50d56a45 100644 --- a/packages/graphrag/graphrag/api/query.py +++ b/packages/graphrag/graphrag/api/query.py @@ -294,7 +294,7 @@ def local_search_streaming( logger.debug(msg) description_embedding_store = get_embedding_store( - store=config.vector_store.model_dump(), + config=config.vector_store, embedding_name=entity_description_embedding, ) @@ -419,12 +419,12 @@ def drift_search_streaming( logger.debug(msg) description_embedding_store = get_embedding_store( - store=config.vector_store.model_dump(), + config=config.vector_store, embedding_name=entity_description_embedding, ) full_content_embedding_store = get_embedding_store( - store=config.vector_store.model_dump(), + config=config.vector_store, embedding_name=community_full_content_embedding, ) @@ -528,7 +528,7 @@ def basic_search_streaming( logger.debug(msg) embedding_store = get_embedding_store( - store=config.vector_store.model_dump(), + config=config.vector_store, embedding_name=text_unit_text_embedding, ) diff --git a/packages/graphrag/graphrag/config/defaults.py b/packages/graphrag/graphrag/config/defaults.py index cc07438c85..7f7fd3f9e8 100644 --- a/packages/graphrag/graphrag/config/defaults.py +++ b/packages/graphrag/graphrag/config/defaults.py @@ -11,6 +11,7 @@ from graphrag_chunking.chunk_strategy_type import ChunkerType from graphrag_input import InputType from graphrag_storage import StorageType +from graphrag_vectors import VectorStoreType from graphrag.config.embeddings import default_embeddings from graphrag.config.enums import ( @@ -19,7 +20,6 @@ ModelType, NounPhraseExtractorType, ReportingType, - VectorStoreType, ) from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import ( EN_STOP_WORDS, @@ -373,13 +373,6 @@ class VectorStoreDefaults: type: ClassVar[str] = VectorStoreType.LanceDB.value db_uri: str = str(Path(DEFAULT_OUTPUT_BASE_DIR) / "lancedb") - overwrite: bool = True - index_prefix: str = "" - url: None = None - api_key: None = None - audience: None = None - database_name: None = None - schema: None = None @dataclass diff --git a/packages/graphrag/graphrag/config/embeddings.py b/packages/graphrag/graphrag/config/embeddings.py index c06567acfa..7456ffca5a 100644 --- a/packages/graphrag/graphrag/config/embeddings.py +++ b/packages/graphrag/graphrag/config/embeddings.py @@ -17,25 +17,3 @@ community_full_content_embedding, text_unit_text_embedding, ] - - -def create_index_name( - index_prefix: str, embedding_name: str, validate: bool = True -) -> str: - """ - Create a index name for the embedding store. - - Within any given vector store, we can have multiple sets of embeddings organized into projects. - The `container` param is used for this partitioning, and is added as a index_prefix to the index name for differentiation. - - The embedding name is fixed, with the available list defined in graphrag.index.config.embeddings - - Note that we use dot notation in our names, but many vector stores do not support this - so we convert to dashes. - """ - if validate and embedding_name not in all_embeddings: - msg = f"Invalid embedding name: {embedding_name}" - raise KeyError(msg) - - if index_prefix: - return f"{index_prefix}-{embedding_name}" - return embedding_name diff --git a/packages/graphrag/graphrag/config/enums.py b/packages/graphrag/graphrag/config/enums.py index 9685ebb07b..8389a724cd 100644 --- a/packages/graphrag/graphrag/config/enums.py +++ b/packages/graphrag/graphrag/config/enums.py @@ -8,14 +8,6 @@ from enum import Enum -class VectorStoreType(str, Enum): - """The supported vector store types.""" - - LanceDB = "lancedb" - AzureAISearch = "azure_ai_search" - CosmosDB = "cosmosdb" - - class ReportingType(str, Enum): """The reporting configuration type for the pipeline.""" diff --git a/packages/graphrag/graphrag/config/models/graph_rag_config.py b/packages/graphrag/graphrag/config/models/graph_rag_config.py index 2e6efd4461..0bb74b2385 100644 --- a/packages/graphrag/graphrag/config/models/graph_rag_config.py +++ b/packages/graphrag/graphrag/config/models/graph_rag_config.py @@ -11,10 +11,12 @@ from graphrag_chunking.chunking_config import ChunkingConfig from graphrag_input import InputConfig from graphrag_storage import StorageConfig, StorageType +from graphrag_vectors import IndexSchema, VectorStoreConfig, VectorStoreType from pydantic import BaseModel, Field, model_validator from graphrag.config.defaults import graphrag_config_defaults -from graphrag.config.enums import ReportingType, VectorStoreType +from graphrag.config.embeddings import all_embeddings +from graphrag.config.enums import ReportingType from graphrag.config.models.basic_search_config import BasicSearchConfig from graphrag.config.models.cluster_graph_config import ClusterGraphConfig from graphrag.config.models.community_reports_config import CommunityReportsConfig @@ -32,7 +34,6 @@ from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) -from graphrag.config.models.vector_store_config import VectorStoreConfig from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import ( RateLimiterFactory, ) @@ -276,13 +277,26 @@ def _validate_reporting_base_dir(self) -> None: ) """The basic search configuration.""" + def _validate_vector_store(self) -> None: + """Validate the vector store configuration specifically in the GraphRAG context. This checks and sets required dynamic defaults for the embeddings we require.""" + self._validate_vector_store_db_uri() + # check and insert/overlay schemas for all of the core embeddings + # note that this does not require that they are used, only that they have a schema + # the embed_text block has the list of actual embeddings + if not self.vector_store.index_schema: + self.vector_store.index_schema = {} + for embedding in all_embeddings: + if embedding not in self.vector_store.index_schema: + self.vector_store.index_schema[embedding] = IndexSchema( + index_name=embedding, + ) + def _validate_vector_store_db_uri(self) -> None: """Validate the vector store configuration.""" store = self.vector_store if store.type == VectorStoreType.LanceDB: if not store.db_uri or store.db_uri.strip == "": - msg = "Vector store URI is required for LanceDB. Please rerun `graphrag init` and set the vector store configuration." - raise ValueError(msg) + store.db_uri = graphrag_config_defaults.vector_store.db_uri store.db_uri = str(Path(store.db_uri).resolve()) def _validate_factories(self) -> None: @@ -321,6 +335,6 @@ def _validate_model(self): self._validate_reporting_base_dir() self._validate_output_base_dir() self._validate_update_output_storage_base_dir() - self._validate_vector_store_db_uri() + self._validate_vector_store() self._validate_factories() return self diff --git a/packages/graphrag/graphrag/config/models/vector_store_config.py b/packages/graphrag/graphrag/config/models/vector_store_config.py deleted file mode 100644 index c2b3e61de1..0000000000 --- a/packages/graphrag/graphrag/config/models/vector_store_config.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Parameterization settings for the default configuration.""" - -from pydantic import BaseModel, Field, model_validator - -from graphrag.config.defaults import vector_store_defaults -from graphrag.config.embeddings import all_embeddings -from graphrag.config.enums import VectorStoreType -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig - - -class VectorStoreConfig(BaseModel): - """The default configuration section for Vector Store.""" - - type: str = Field( - description="The vector store type to use.", - default=vector_store_defaults.type, - ) - - db_uri: str | None = Field( - description="The database URI to use.", - default=None, - ) - - def _validate_db_uri(self) -> None: - """Validate the database URI.""" - if self.type == VectorStoreType.LanceDB.value and ( - self.db_uri is None or self.db_uri.strip() == "" - ): - self.db_uri = vector_store_defaults.db_uri - - if self.type != VectorStoreType.LanceDB.value and ( - self.db_uri is not None and self.db_uri.strip() != "" - ): - msg = "vector_store.db_uri is only used when vector_store.type == lancedb. Please rerun `graphrag init` and select the correct vector store type." - raise ValueError(msg) - - url: str | None = Field( - description="The database URL when type == azure_ai_search.", - default=vector_store_defaults.url, - ) - - def _validate_url(self) -> None: - """Validate the database URL.""" - if self.type == VectorStoreType.AzureAISearch and ( - self.url is None or self.url.strip() == "" - ): - msg = "vector_store.url is required when vector_store.type == azure_ai_search. Please rerun `graphrag init` and select the correct vector store type." - raise ValueError(msg) - - if self.type == VectorStoreType.CosmosDB and ( - self.url is None or self.url.strip() == "" - ): - msg = "vector_store.url is required when vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type." - raise ValueError(msg) - - if self.type == VectorStoreType.LanceDB and ( - self.url is not None and self.url.strip() != "" - ): - msg = "vector_store.url is only used when vector_store.type == azure_ai_search or vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type." - raise ValueError(msg) - - api_key: str | None = Field( - description="The database API key when type == azure_ai_search.", - default=vector_store_defaults.api_key, - ) - - audience: str | None = Field( - description="The database audience when type == azure_ai_search.", - default=vector_store_defaults.audience, - ) - - index_prefix: str | None = Field( - description="The index prefix to use.", - default=vector_store_defaults.index_prefix, - ) - - database_name: str | None = Field( - description="The database name to use when type == cosmos_db.", - default=vector_store_defaults.database_name, - ) - - embeddings_schema: dict[str, VectorStoreSchemaConfig] = {} - - def _validate_embeddings_schema(self) -> None: - """Validate the embeddings schema.""" - for name in self.embeddings_schema: - if name not in all_embeddings: - msg = f"vector_store.embeddings_schema contains an invalid embedding schema name: {name}. Please update your settings.yaml and select the correct embedding schema names." - raise ValueError(msg) - - if self.type == VectorStoreType.CosmosDB: - for id_field in self.embeddings_schema: - if id_field != "id": - msg = "When using CosmosDB, the id_field in embeddings_schema must be 'id'. Please update your settings.yaml and set the id_field to 'id'." - raise ValueError(msg) - - @model_validator(mode="after") - def _validate_model(self): - """Validate the model.""" - self._validate_db_uri() - self._validate_url() - self._validate_embeddings_schema() - return self diff --git a/packages/graphrag/graphrag/index/operations/embed_text/embed_text.py b/packages/graphrag/graphrag/index/operations/embed_text/embed_text.py index 7ca6c80920..2e075ac655 100644 --- a/packages/graphrag/graphrag/index/operations/embed_text/embed_text.py +++ b/packages/graphrag/graphrag/index/operations/embed_text/embed_text.py @@ -7,12 +7,12 @@ import numpy as np import pandas as pd +from graphrag_vectors import VectorStore, VectorStoreDocument from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.operations.embed_text.run_embed_text import run_embed_text from graphrag.language_model.protocol.base import EmbeddingModel from graphrag.tokenizer.tokenizer import Tokenizer -from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ async def embed_text( batch_size: int, batch_max_tokens: int, num_threads: int, - vector_store: BaseVectorStore, + vector_store: VectorStore, id_column: str = "id", ): """Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector.""" diff --git a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py index ef24d3e348..104631c451 100644 --- a/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py +++ b/packages/graphrag/graphrag/index/workflows/generate_text_embeddings.py @@ -6,17 +6,18 @@ import logging import pandas as pd +from graphrag_vectors import ( + VectorStoreConfig, + create_vector_store, +) from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.embeddings import ( community_full_content_embedding, - create_index_name, entity_description_embedding, text_unit_text_embedding, ) from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.config.models.vector_store_config import VectorStoreConfig -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig from graphrag.index.operations.embed_text.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput @@ -28,8 +29,6 @@ load_table_from_storage, write_table_to_storage, ) -from graphrag.vector_stores.base import BaseVectorStore -from graphrag.vector_stores.factory import VectorStoreFactory logger = logging.getLogger(__name__) @@ -163,8 +162,10 @@ async def _run_embeddings( vector_store_config: VectorStoreConfig, ) -> pd.DataFrame: """All the steps to generate single embedding.""" - index_name = _get_index_name(vector_store_config, name) - vector_store = _create_vector_store(vector_store_config, index_name, name) + vector_store = create_vector_store( + vector_store_config, vector_store_config.index_schema[name] + ) + vector_store.connect() data["embedding"] = await embed_text( input=data, @@ -179,56 +180,3 @@ async def _run_embeddings( ) return data.loc[:, ["id", "embedding"]] - - -def _create_vector_store( - vector_store_config: VectorStoreConfig, - index_name: str, - embedding_name: str | None = None, -) -> BaseVectorStore: - embeddings_schema: dict[str, VectorStoreSchemaConfig] = ( - vector_store_config.embeddings_schema - ) - - single_embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig() - - if ( - embeddings_schema is not None - and embedding_name is not None - and embedding_name in embeddings_schema - ): - raw_config = embeddings_schema[embedding_name] - if isinstance(raw_config, dict): - single_embedding_config = VectorStoreSchemaConfig(**raw_config) - else: - single_embedding_config = raw_config - - if ( - single_embedding_config.index_name is not None - and vector_store_config.index_prefix - ): - single_embedding_config.index_name = ( - f"{vector_store_config.index_prefix}-{single_embedding_config.index_name}" - ) - - if single_embedding_config.index_name is None: - single_embedding_config.index_name = index_name - - args = vector_store_config.model_dump() - args["vector_store_schema_config"] = single_embedding_config - vector_store = VectorStoreFactory().create( - vector_store_config.type, - args, - ) - - vector_store.connect(**args) - return vector_store - - -def _get_index_name(vector_store_config: VectorStoreConfig, embedding_name: str) -> str: - index_prefix = vector_store_config.index_prefix or "" - index_name = create_index_name(index_prefix, embedding_name) - - msg = f"using vector store {vector_store_config.type} with index prefix {index_prefix} for embedding {embedding_name}: {index_name}" - logger.info(msg) - return index_name diff --git a/packages/graphrag/graphrag/query/context_builder/entity_extraction.py b/packages/graphrag/graphrag/query/context_builder/entity_extraction.py index 0dd03ba281..289b3b7ea0 100644 --- a/packages/graphrag/graphrag/query/context_builder/entity_extraction.py +++ b/packages/graphrag/graphrag/query/context_builder/entity_extraction.py @@ -5,6 +5,8 @@ from enum import Enum +from graphrag_vectors import VectorStore + from graphrag.data_model.entity import Entity from graphrag.data_model.relationship import Relationship from graphrag.language_model.protocol.base import EmbeddingModel @@ -13,7 +15,6 @@ get_entity_by_key, get_entity_by_name, ) -from graphrag.vector_stores.base import BaseVectorStore class EntityVectorStoreKey(str, Enum): @@ -36,7 +37,7 @@ def from_string(value: str) -> "EntityVectorStoreKey": def map_query_to_entities( query: str, - text_embedding_vectorstore: BaseVectorStore, + text_embedding_vectorstore: VectorStore, text_embedder: EmbeddingModel, all_entities_dict: dict[str, Entity], embedding_vectorstore_key: str = EntityVectorStoreKey.ID, diff --git a/packages/graphrag/graphrag/query/factory.py b/packages/graphrag/graphrag/query/factory.py index 3f73c8c1ab..4ff36100cc 100644 --- a/packages/graphrag/graphrag/query/factory.py +++ b/packages/graphrag/graphrag/query/factory.py @@ -3,6 +3,8 @@ """Query Factory methods to support CLI.""" +from graphrag_vectors import VectorStore + from graphrag.callbacks.query_callbacks import QueryCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.community import Community @@ -33,7 +35,6 @@ ) from graphrag.query.structured_search.local_search.search import LocalSearch from graphrag.tokenizer.get_tokenizer import get_tokenizer -from graphrag.vector_stores.base import BaseVectorStore def get_local_search_engine( @@ -44,7 +45,7 @@ def get_local_search_engine( relationships: list[Relationship], covariates: dict[str, list[Covariate]], response_type: str, - description_embedding_store: BaseVectorStore, + description_embedding_store: VectorStore, system_prompt: str | None = None, callbacks: list[QueryCallbacks] | None = None, ) -> LocalSearch: @@ -198,7 +199,7 @@ def get_drift_search_engine( text_units: list[TextUnit], entities: list[Entity], relationships: list[Relationship], - description_embedding_store: BaseVectorStore, + description_embedding_store: VectorStore, response_type: str, local_system_prompt: str | None = None, reduce_system_prompt: str | None = None, @@ -249,7 +250,7 @@ def get_drift_search_engine( def get_basic_search_engine( text_units: list[TextUnit], - text_unit_embeddings: BaseVectorStore, + text_unit_embeddings: VectorStore, config: GraphRagConfig, response_type: str, system_prompt: str | None = None, diff --git a/packages/graphrag/graphrag/query/indexer_adapters.py b/packages/graphrag/graphrag/query/indexer_adapters.py index 56064d8c2a..f0d5ff7dae 100644 --- a/packages/graphrag/graphrag/query/indexer_adapters.py +++ b/packages/graphrag/graphrag/query/indexer_adapters.py @@ -10,6 +10,7 @@ from typing import cast import pandas as pd +from graphrag_vectors import VectorStore from graphrag.data_model.community import Community from graphrag.data_model.community_report import CommunityReport @@ -26,7 +27,6 @@ read_relationships, read_text_units, ) -from graphrag.vector_stores.base import BaseVectorStore logger = logging.getLogger(__name__) @@ -103,7 +103,7 @@ def read_indexer_reports( def read_indexer_report_embeddings( community_reports: list[CommunityReport], - embeddings_store: BaseVectorStore, + embeddings_store: VectorStore, ): """Read in the Community Reports from the raw indexing outputs.""" for report in community_reports: diff --git a/packages/graphrag/graphrag/query/structured_search/basic_search/basic_context.py b/packages/graphrag/graphrag/query/structured_search/basic_search/basic_context.py index b7390017fc..57bb16efca 100644 --- a/packages/graphrag/graphrag/query/structured_search/basic_search/basic_context.py +++ b/packages/graphrag/graphrag/query/structured_search/basic_search/basic_context.py @@ -7,6 +7,7 @@ from typing import cast import pandas as pd +from graphrag_vectors import VectorStore from graphrag.data_model.text_unit import TextUnit from graphrag.language_model.protocol.base import EmbeddingModel @@ -17,7 +18,6 @@ from graphrag.query.context_builder.conversation_history import ConversationHistory from graphrag.tokenizer.get_tokenizer import get_tokenizer from graphrag.tokenizer.tokenizer import Tokenizer -from graphrag.vector_stores.base import BaseVectorStore logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ class BasicSearchContext(BasicContextBuilder): def __init__( self, text_embedder: EmbeddingModel, - text_unit_embeddings: BaseVectorStore, + text_unit_embeddings: VectorStore, text_units: list[TextUnit] | None = None, tokenizer: Tokenizer | None = None, embedding_vectorstore_key: str = "id", diff --git a/packages/graphrag/graphrag/query/structured_search/drift_search/drift_context.py b/packages/graphrag/graphrag/query/structured_search/drift_search/drift_context.py index 9e1e9c317a..41649f2f6c 100644 --- a/packages/graphrag/graphrag/query/structured_search/drift_search/drift_context.py +++ b/packages/graphrag/graphrag/query/structured_search/drift_search/drift_context.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd +from graphrag_vectors import VectorStore from graphrag.config.models.drift_search_config import DRIFTSearchConfig from graphrag.data_model.community_report import CommunityReport @@ -29,7 +30,6 @@ ) from graphrag.tokenizer.get_tokenizer import get_tokenizer from graphrag.tokenizer.tokenizer import Tokenizer -from graphrag.vector_stores.base import BaseVectorStore logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def __init__( config: DRIFTSearchConfig, text_embedder: EmbeddingModel, entities: list[Entity], - entity_text_embeddings: BaseVectorStore, + entity_text_embeddings: VectorStore, text_units: list[TextUnit] | None = None, reports: list[CommunityReport] | None = None, relationships: list[Relationship] | None = None, diff --git a/packages/graphrag/graphrag/query/structured_search/local_search/mixed_context.py b/packages/graphrag/graphrag/query/structured_search/local_search/mixed_context.py index b91272d164..a3540dfcb2 100644 --- a/packages/graphrag/graphrag/query/structured_search/local_search/mixed_context.py +++ b/packages/graphrag/graphrag/query/structured_search/local_search/mixed_context.py @@ -7,6 +7,7 @@ from typing import Any import pandas as pd +from graphrag_vectors import VectorStore from graphrag.data_model.community_report import CommunityReport from graphrag.data_model.covariate import Covariate @@ -42,7 +43,6 @@ from graphrag.query.structured_search.base import LocalContextBuilder from graphrag.tokenizer.get_tokenizer import get_tokenizer from graphrag.tokenizer.tokenizer import Tokenizer -from graphrag.vector_stores.base import BaseVectorStore logger = logging.getLogger(__name__) @@ -53,7 +53,7 @@ class LocalSearchMixedContext(LocalContextBuilder): def __init__( self, entities: list[Entity], - entity_text_embeddings: BaseVectorStore, + entity_text_embeddings: VectorStore, text_embedder: EmbeddingModel, text_units: list[TextUnit] | None = None, community_reports: list[CommunityReport] | None = None, diff --git a/packages/graphrag/graphrag/utils/api.py b/packages/graphrag/graphrag/utils/api.py index 2aef931170..980997db06 100644 --- a/packages/graphrag/graphrag/utils/api.py +++ b/packages/graphrag/graphrag/utils/api.py @@ -4,48 +4,21 @@ """API functions for the GraphRAG module.""" from pathlib import Path -from typing import Any -from graphrag.config.embeddings import create_index_name -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.vector_stores.base import ( - BaseVectorStore, +from graphrag_vectors import ( + VectorStore, + VectorStoreConfig, + create_vector_store, ) -from graphrag.vector_stores.factory import VectorStoreFactory def get_embedding_store( - store: dict[str, Any], + config: VectorStoreConfig, embedding_name: str, -) -> BaseVectorStore: - """Get the embedding description store.""" - vector_store_type = store["type"] - index_name = create_index_name(store.get("index_prefix", ""), embedding_name) - - embeddings_schema: dict[str, VectorStoreSchemaConfig] = store.get( - "embeddings_schema", {} - ) - embedding_config: VectorStoreSchemaConfig = VectorStoreSchemaConfig() - - if ( - embeddings_schema is not None - and embedding_name is not None - and embedding_name in embeddings_schema - ): - raw_config = embeddings_schema[embedding_name] - if isinstance(raw_config, dict): - embedding_config = VectorStoreSchemaConfig(**raw_config) - else: - embedding_config = raw_config - - if embedding_config.index_name is None: - embedding_config.index_name = index_name - - embedding_store = VectorStoreFactory().create( - vector_store_type, - {**store, "vector_store_schema_config": embedding_config}, - ) - embedding_store.connect(**store) +) -> VectorStore: + """Get the embedding store.""" + embedding_store = create_vector_store(config, config.index_schema[embedding_name]) + embedding_store.connect() return embedding_store diff --git a/packages/graphrag/graphrag/vector_stores/__init__.py b/packages/graphrag/graphrag/vector_stores/__init__.py deleted file mode 100644 index 4f137d07bb..0000000000 --- a/packages/graphrag/graphrag/vector_stores/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""A package containing vector store implementations.""" diff --git a/packages/graphrag/graphrag/vector_stores/factory.py b/packages/graphrag/graphrag/vector_stores/factory.py deleted file mode 100644 index 00a451771c..0000000000 --- a/packages/graphrag/graphrag/vector_stores/factory.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -"""Factory functions for creating a vector store.""" - -from __future__ import annotations - -from graphrag_common.factory import Factory - -from graphrag.config.enums import VectorStoreType -from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore -from graphrag.vector_stores.base import BaseVectorStore -from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore -from graphrag.vector_stores.lancedb import LanceDBVectorStore - - -class VectorStoreFactory(Factory[BaseVectorStore]): - """A factory for vector stores. - - Includes a method for users to register a custom vector store implementation. - - Configuration arguments are passed to each vector store implementation as kwargs - for individual enforcement of required/optional arguments. - """ - - -# --- register built-in vector store implementations --- -vector_store_factory = VectorStoreFactory() -vector_store_factory.register(VectorStoreType.LanceDB.value, LanceDBVectorStore) -vector_store_factory.register( - VectorStoreType.AzureAISearch.value, AzureAISearchVectorStore -) -vector_store_factory.register(VectorStoreType.CosmosDB.value, CosmosDBVectorStore) diff --git a/packages/graphrag/pyproject.toml b/packages/graphrag/pyproject.toml index 769634bcb3..2db8c4f835 100644 --- a/packages/graphrag/pyproject.toml +++ b/packages/graphrag/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "graphrag-common==2.7.0", "graphrag-input==2.7.0", "graphrag-storage==2.7.0", + "graphrag-vectors==2.7.0", "graspologic-native~=1.2", "json-repair~=0.30", "lancedb~=0.24.1", diff --git a/pyproject.toml b/pyproject.toml index b2ec49f4c1..a14efa5ab0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ graphrag-common = { workspace = true } graphrag-input = { workspace = true } graphrag-storage = { workspace = true } graphrag-cache = { workspace = true } +graphrag-vectors = { workspace = true } # Keep poethepoet for task management to minimize changes [tool.poe.tasks] @@ -77,6 +78,7 @@ _semversioner_update_graphrag_common_toml_version = "update-toml update --file p _semversioner_update_graphrag_storage_toml_version = "update-toml update --file packages/graphrag-storage/pyproject.toml --path project.version --value $(uv run semversioner current-version)" _semversioner_update_graphrag_cache_toml_version = "update-toml update --file packages/graphrag-cache/pyproject.toml --path project.version --value $(uv run semversioner current-version)" _semversioner_update_graphrag_input_toml_version = "update-toml update --file packages/graphrag-input/pyproject.toml --path project.version --value $(uv run semversioner current-version)" +_semversioner_update_graphrag_vectors_toml_version = "update-toml update --file packages/graphrag-vectors/pyproject.toml --path project.version --value $(uv run semversioner current-version)" _semversioner_update_workspace_dependency_versions = "python -m scripts.update_workspace_dependency_versions" semversioner_add = "semversioner add-change" coverage_report = 'coverage report --omit "**/tests/**" --show-missing' diff --git a/tests/fixtures/min-csv/settings.yml b/tests/fixtures/min-csv/settings.yml index c6ebb2248e..bf7aa59644 100644 --- a/tests/fixtures/min-csv/settings.yml +++ b/tests/fixtures/min-csv/settings.yml @@ -32,7 +32,6 @@ vector_store: type: "lancedb" db_uri: "./tests/fixtures/min-csv/lancedb" container_name: "lancedb_ci" - overwrite: True input: type: csv diff --git a/tests/integration/vector_stores/test_azure_ai_search.py b/tests/integration/vector_stores/test_azure_ai_search.py index 9f35f62fbd..ffd445508c 100644 --- a/tests/integration/vector_stores/test_azure_ai_search.py +++ b/tests/integration/vector_stores/test_azure_ai_search.py @@ -7,9 +7,10 @@ from unittest.mock import MagicMock, patch import pytest -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore -from graphrag.vector_stores.base import VectorStoreDocument +from graphrag_vectors import ( + VectorStoreDocument, +) +from graphrag_vectors.azure_ai_search import AzureAISearchVectorStore TEST_AZURE_AI_SEARCH_URL = os.environ.get( "TEST_AZURE_AI_SEARCH_URL", "https://test-url.search.windows.net" @@ -23,58 +24,49 @@ class TestAzureAISearchVectorStore: @pytest.fixture def mock_search_client(self): """Create a mock Azure AI Search client.""" - with patch( - "graphrag.vector_stores.azure_ai_search.SearchClient" - ) as mock_client: + with patch("graphrag_vectors.azure_ai_search.SearchClient") as mock_client: yield mock_client.return_value @pytest.fixture def mock_index_client(self): """Create a mock Azure AI Search index client.""" - with patch( - "graphrag.vector_stores.azure_ai_search.SearchIndexClient" - ) as mock_client: + with patch("graphrag_vectors.azure_ai_search.SearchIndexClient") as mock_client: yield mock_client.return_value @pytest.fixture def vector_store(self, mock_search_client, mock_index_client): """Create an Azure AI Search vector store instance.""" vector_store = AzureAISearchVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig( - index_name="test_vectors", vector_size=5 - ), + url=TEST_AZURE_AI_SEARCH_URL, + api_key=TEST_AZURE_AI_SEARCH_KEY, + index_name="test_vectors", + vector_size=5, ) # Create the necessary mocks first vector_store.db_connection = mock_search_client vector_store.index_client = mock_index_client - vector_store.connect( - url=TEST_AZURE_AI_SEARCH_URL, - api_key=TEST_AZURE_AI_SEARCH_KEY, - ) + vector_store.connect() return vector_store @pytest.fixture def vector_store_custom(self, mock_search_client, mock_index_client): """Create an Azure AI Search vector store instance.""" vector_store = AzureAISearchVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig( - index_name="test_vectors", - id_field="id_custom", - vector_field="vector_custom", - vector_size=5, - ), + url=TEST_AZURE_AI_SEARCH_URL, + api_key=TEST_AZURE_AI_SEARCH_KEY, + index_name="test_vectors", + id_field="id_custom", + vector_field="vector_custom", + vector_size=5, ) # Create the necessary mocks first vector_store.db_connection = mock_search_client vector_store.index_client = mock_index_client - vector_store.connect( - url=TEST_AZURE_AI_SEARCH_URL, - api_key=TEST_AZURE_AI_SEARCH_KEY, - ) + vector_store.connect() return vector_store @pytest.fixture diff --git a/tests/integration/vector_stores/test_cosmosdb.py b/tests/integration/vector_stores/test_cosmosdb.py index 331ef1f297..de30bb6c50 100644 --- a/tests/integration/vector_stores/test_cosmosdb.py +++ b/tests/integration/vector_stores/test_cosmosdb.py @@ -7,9 +7,10 @@ import numpy as np import pytest -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.vector_stores.base import VectorStoreDocument -from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore +from graphrag_vectors import ( + VectorStoreDocument, +) +from graphrag_vectors.cosmosdb import CosmosDBVectorStore # cspell:disable-next-line well-known-key WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://127.0.0.1:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" @@ -24,14 +25,13 @@ def test_vector_store_operations(): """Test basic vector store operations with CosmosDB.""" vector_store = CosmosDBVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig(index_name="testvector"), + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="test_db", + index_name="testvector", ) try: - vector_store.connect( - connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - database_name="test_db", - ) + vector_store.connect() docs = [ VectorStoreDocument( @@ -72,13 +72,12 @@ def mock_embedder(text: str) -> list[float]: def test_clear(): """Test clearing the vector store.""" vector_store = CosmosDBVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig(index_name="testclear"), + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="testclear", + index_name="testclear", ) try: - vector_store.connect( - connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - database_name="testclear", - ) + vector_store.connect() doc = VectorStoreDocument( id="test", @@ -100,19 +99,16 @@ def test_clear(): def test_vector_store_customization(): """Test vector store customization with CosmosDB.""" vector_store = CosmosDBVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig( - index_name="text-embeddings", - id_field="id", - vector_field="vector_custom", - vector_size=5, - ), + connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, + database_name="test_db", + index_name="text-embeddings", + id_field="id", + vector_field="vector_custom", + vector_size=5, ) try: - vector_store.connect( - connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING, - database_name="test_db", - ) + vector_store.connect() docs = [ VectorStoreDocument( diff --git a/tests/integration/vector_stores/test_factory.py b/tests/integration/vector_stores/test_factory.py index a8b797dbb2..ceb2a84c84 100644 --- a/tests/integration/vector_stores/test_factory.py +++ b/tests/integration/vector_stores/test_factory.py @@ -6,28 +6,28 @@ """ import pytest -from graphrag.config.enums import VectorStoreType -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore -from graphrag.vector_stores.base import BaseVectorStore -from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore -from graphrag.vector_stores.factory import VectorStoreFactory -from graphrag.vector_stores.lancedb import LanceDBVectorStore +from graphrag_vectors import ( + VectorStore, + VectorStoreFactory, + VectorStoreType, +) +from graphrag_vectors.azure_ai_search import AzureAISearchVectorStore +from graphrag_vectors.cosmosdb import CosmosDBVectorStore +from graphrag_vectors.lancedb import LanceDBVectorStore + +# register the defaults, since they are lazily registered +VectorStoreFactory().register(VectorStoreType.LanceDB, LanceDBVectorStore) +VectorStoreFactory().register(VectorStoreType.AzureAISearch, AzureAISearchVectorStore) +VectorStoreFactory().register(VectorStoreType.CosmosDB, CosmosDBVectorStore) def test_create_lancedb_vector_store(): kwargs = { "db_uri": "/tmp/lancedb", - "vector_store_schema_config": VectorStoreSchemaConfig( - index_name="test_collection" - ), } - vector_store = VectorStoreFactory().create( - VectorStoreType.LanceDB.value, - kwargs, - ) + vector_store = VectorStoreFactory().create(VectorStoreType.LanceDB, kwargs) assert isinstance(vector_store, LanceDBVectorStore) - assert vector_store.index_name == "test_collection" + assert vector_store.index_name == "vector_index" @pytest.mark.skip(reason="Azure AI Search requires credentials and setup") @@ -35,12 +35,10 @@ def test_create_azure_ai_search_vector_store(): kwargs = { "url": "https://test.search.windows.net", "api_key": "test_key", - "vector_store_schema_config": VectorStoreSchemaConfig( - index_name="test_collection" - ), + "index_name": "test_collection", } vector_store = VectorStoreFactory().create( - VectorStoreType.AzureAISearch.value, + VectorStoreType.AzureAISearch, kwargs, ) assert isinstance(vector_store, AzureAISearchVectorStore) @@ -51,13 +49,11 @@ def test_create_cosmosdb_vector_store(): kwargs = { "connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==", "database_name": "test_db", - "vector_store_schema_config": VectorStoreSchemaConfig( - index_name="test_collection" - ), + "index_name": "test_collection", } vector_store = VectorStoreFactory().create( - VectorStoreType.CosmosDB.value, + VectorStoreType.CosmosDB, kwargs, ) @@ -68,8 +64,8 @@ def test_register_and_create_custom_vector_store(): """Test registering and creating a custom vector store type.""" from unittest.mock import MagicMock - # Create a mock that satisfies the BaseVectorStore interface - custom_vector_store_class = MagicMock(spec=BaseVectorStore) + # Create a mock that satisfies the VectorStore interface + custom_vector_store_class = MagicMock(spec=VectorStore) # Make the mock return a mock instance when instantiated instance = MagicMock() instance.initialized = True @@ -79,9 +75,7 @@ def test_register_and_create_custom_vector_store(): "custom", lambda **kwargs: custom_vector_store_class(**kwargs) ) - vector_store = VectorStoreFactory().create( - "custom", {"vector_store_schema_config": VectorStoreSchemaConfig()} - ) + vector_store = VectorStoreFactory().create("custom", {}) assert custom_vector_store_class.called assert vector_store is instance @@ -99,9 +93,9 @@ def test_create_unknown_vector_store(): def test_is_supported_type(): # Test built-in types - assert VectorStoreType.LanceDB.value in VectorStoreFactory() - assert VectorStoreType.AzureAISearch.value in VectorStoreFactory() - assert VectorStoreType.CosmosDB.value in VectorStoreFactory() + assert VectorStoreType.LanceDB in VectorStoreFactory() + assert VectorStoreType.AzureAISearch in VectorStoreFactory() + assert VectorStoreType.CosmosDB in VectorStoreFactory() # Test unknown type assert "unknown" not in VectorStoreFactory() @@ -109,9 +103,9 @@ def test_is_supported_type(): def test_register_class_directly_works(): """Test that registering a class directly works.""" - from graphrag.vector_stores.base import BaseVectorStore + from graphrag_vectors import VectorStore - class CustomVectorStore(BaseVectorStore): + class CustomVectorStore(VectorStore): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -131,7 +125,7 @@ def similarity_search_by_text(self, text, text_embedder, k=10, **kwargs): return [] def search_by_id(self, id): - from graphrag.vector_stores.base import VectorStoreDocument + from graphrag_vectors import VectorStoreDocument return VectorStoreDocument(id=id, vector=None) @@ -144,7 +138,7 @@ def search_by_id(self, id): # Test creating an instance vector_store = VectorStoreFactory().create( "custom_class", - {"vector_store_schema_config": VectorStoreSchemaConfig()}, + {}, ) assert isinstance(vector_store, CustomVectorStore) diff --git a/tests/integration/vector_stores/test_lancedb.py b/tests/integration/vector_stores/test_lancedb.py index 628c89f012..dbd1198556 100644 --- a/tests/integration/vector_stores/test_lancedb.py +++ b/tests/integration/vector_stores/test_lancedb.py @@ -8,9 +8,10 @@ import numpy as np import pytest -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig -from graphrag.vector_stores.base import VectorStoreDocument -from graphrag.vector_stores.lancedb import LanceDBVectorStore +from graphrag_vectors import ( + VectorStoreDocument, +) +from graphrag_vectors.lancedb import LanceDBVectorStore class TestLanceDBVectorStore: @@ -58,11 +59,9 @@ def test_vector_store_operations(self, sample_documents): temp_dir = tempfile.mkdtemp() try: vector_store = LanceDBVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig( - index_name="test_collection", vector_size=5 - ) + db_uri=temp_dir, index_name="test_collection", vector_size=5 ) - vector_store.connect(db_uri=temp_dir) + vector_store.connect() vector_store.create_index() vector_store.load_documents(sample_documents[:2]) @@ -111,11 +110,9 @@ def test_empty_collection(self): temp_dir = tempfile.mkdtemp() try: vector_store = LanceDBVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig( - index_name="empty_collection", vector_size=5 - ) + db_uri=temp_dir, index_name="empty_collection", vector_size=5 ) - vector_store.connect(db_uri=temp_dir) + vector_store.connect() # Load the vector store with a document, then delete it sample_doc = VectorStoreDocument( @@ -154,12 +151,10 @@ def test_filter_search(self, sample_documents_categories): temp_dir = tempfile.mkdtemp() try: vector_store = LanceDBVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig( - index_name="filter_collection", vector_size=5 - ) + db_uri=temp_dir, index_name="filter_collection", vector_size=5 ) - vector_store.connect(db_uri=temp_dir) + vector_store.connect() vector_store.create_index() vector_store.load_documents(sample_documents_categories) @@ -181,14 +176,13 @@ def test_vector_store_customization(self, sample_documents): temp_dir = tempfile.mkdtemp() try: vector_store = LanceDBVectorStore( - vector_store_schema_config=VectorStoreSchemaConfig( - index_name="text-embeddings", - id_field="id_custom", - vector_field="vector_custom", - vector_size=5, - ), + db_uri=temp_dir, + index_name="text-embeddings", + id_field="id_custom", + vector_field="vector_custom", + vector_size=5, ) - vector_store.connect(db_uri=temp_dir) + vector_store.connect() vector_store.create_index() vector_store.load_documents(sample_documents[:2]) diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index c02285125a..813502d548 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -25,11 +25,11 @@ from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) -from graphrag.config.models.vector_store_config import VectorStoreConfig from graphrag_cache import CacheConfig from graphrag_chunking.chunking_config import ChunkingConfig from graphrag_input import InputConfig from graphrag_storage import StorageConfig +from graphrag_vectors import VectorStoreConfig from pydantic import BaseModel FAKE_API_KEY = "NOT_AN_API_KEY" @@ -112,7 +112,6 @@ def assert_vector_store_configs( assert actual.url == expected.url assert actual.api_key == expected.api_key assert actual.audience == expected.audience - assert actual.index_prefix == expected.index_prefix assert actual.database_name == expected.database_name diff --git a/tests/unit/query/context_builder/test_entity_extraction.py b/tests/unit/query/context_builder/test_entity_extraction.py index 1fd5e41474..0d1c7c0018 100644 --- a/tests/unit/query/context_builder/test_entity_extraction.py +++ b/tests/unit/query/context_builder/test_entity_extraction.py @@ -3,26 +3,23 @@ from typing import Any -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig from graphrag.data_model.entity import Entity -from graphrag.data_model.types import TextEmbedder from graphrag.language_model.manager import ModelManager from graphrag.query.context_builder.entity_extraction import ( EntityVectorStoreKey, map_query_to_entities, ) -from graphrag.vector_stores.base import ( - BaseVectorStore, +from graphrag_vectors import ( + TextEmbedder, + VectorStore, VectorStoreDocument, VectorStoreSearchResult, ) -class MockBaseVectorStore(BaseVectorStore): +class MockVectorStore(VectorStore): def __init__(self, documents: list[VectorStoreDocument]) -> None: - super().__init__( - vector_store_schema_config=VectorStoreSchemaConfig(index_name="mock") - ) + super().__init__(index_name="mock") self.documents = documents def connect(self, **kwargs: Any) -> None: @@ -92,7 +89,7 @@ def test_map_query_to_entities(): assert map_query_to_entities( query="t22", - text_embedding_vectorstore=MockBaseVectorStore([ + text_embedding_vectorstore=MockVectorStore([ VectorStoreDocument(id=entity.title, vector=None) for entity in entities ]), text_embedder=ModelManager().get_or_create_embedding_model( @@ -113,7 +110,7 @@ def test_map_query_to_entities(): assert map_query_to_entities( query="", - text_embedding_vectorstore=MockBaseVectorStore([ + text_embedding_vectorstore=MockVectorStore([ VectorStoreDocument(id=entity.id, vector=None) for entity in entities ]), text_embedder=ModelManager().get_or_create_embedding_model( diff --git a/tests/unit/utils/test_embeddings.py b/tests/unit/utils/test_embeddings.py deleted file mode 100644 index 63d0619c1d..0000000000 --- a/tests/unit/utils/test_embeddings.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License - -import pytest -from graphrag.config.embeddings import create_index_name - - -def test_create_index_name(): - collection = create_index_name("default", "entity_description") - assert collection == "default-entity_description" - - -def test_create_index_name_invalid_embedding_throws(): - with pytest.raises(KeyError): - create_index_name("default", "invalid.name") - - -def test_create_index_name_invalid_embedding_does_not_throw(): - collection = create_index_name("default", "invalid_name", validate=False) - assert collection == "default-invalid_name" diff --git a/uv.lock b/uv.lock index 72a6db6f11..ddff9f008b 100644 --- a/uv.lock +++ b/uv.lock @@ -19,6 +19,7 @@ members = [ "graphrag-input", "graphrag-monorepo", "graphrag-storage", + "graphrag-vectors", ] [[package]] @@ -1009,6 +1010,7 @@ dependencies = [ { name = "graphrag-common" }, { name = "graphrag-input" }, { name = "graphrag-storage" }, + { name = "graphrag-vectors" }, { name = "graspologic-native" }, { name = "json-repair" }, { name = "lancedb" }, @@ -1043,6 +1045,7 @@ requires-dist = [ { name = "graphrag-common", editable = "packages/graphrag-common" }, { name = "graphrag-input", editable = "packages/graphrag-input" }, { name = "graphrag-storage", editable = "packages/graphrag-storage" }, + { name = "graphrag-vectors", editable = "packages/graphrag-vectors" }, { name = "graspologic-native", specifier = "~=1.2" }, { name = "json-repair", specifier = "~=0.30" }, { name = "lancedb", specifier = "~=0.24.1" }, @@ -1204,6 +1207,35 @@ requires-dist = [ { name = "pydantic", specifier = "~=2.10" }, ] +[[package]] +name = "graphrag-vectors" +version = "2.7.0" +source = { editable = "packages/graphrag-vectors" } +dependencies = [ + { name = "azure-core" }, + { name = "azure-cosmos" }, + { name = "azure-identity" }, + { name = "azure-search-documents" }, + { name = "graphrag-common" }, + { name = "lancedb" }, + { name = "numpy" }, + { name = "pyarrow" }, + { name = "pydantic" }, +] + +[package.metadata] +requires-dist = [ + { name = "azure-core", specifier = "~=1.32" }, + { name = "azure-cosmos", specifier = "~=4.9" }, + { name = "azure-identity", specifier = "~=1.19" }, + { name = "azure-search-documents", specifier = "~=11.6" }, + { name = "graphrag-common", editable = "packages/graphrag-common" }, + { name = "lancedb", specifier = "~=0.24.1" }, + { name = "numpy", specifier = "~=2.2" }, + { name = "pyarrow", specifier = "~=22.0" }, + { name = "pydantic", specifier = "~=2.10" }, +] + [[package]] name = "graspologic-native" version = "1.2.5"