diff --git a/docs/config/yaml.md b/docs/config/yaml.md
index 791b67341c..c4c150f0be 100644
--- a/docs/config/yaml.md
+++ b/docs/config/yaml.md
@@ -163,7 +163,7 @@ Where to put all vectors for the system. Configured for lancedb by default. This
 
 - `type` **lancedb|azure_ai_search|cosmosdb** - Type of vector store. Default=`lancedb`
 - `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb`
-- `url` **str** (only for AI Search) - AI Search endpoint
+- `url` **str** (only for AI Search or cosmosdb) - AI Search/Cosmos DB endpoint
 - `api_key` **str** (optional - only for AI Search) - The AI Search api key to use.
 - `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used.
 - `container_name` **str** - The name of a vector container. This stores all indexes (tables) for a given dataset ingest. Default=`default`
diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py
index c4c5b780c3..bb09251a5b 100644
--- a/graphrag/config/models/graph_rag_config.py
+++ b/graphrag/config/models/graph_rag_config.py
@@ -282,6 +282,16 @@ def _validate_reporting_base_dir(self) -> None:
     )
     """The basic search configuration."""
 
+    user_id: str = Field(
+        description="User ID", default=""
+    )
+    """The user ID."""
+
+    def _validate_user_id(self) -> None:
+        if not self.user_id:
+            msg = "User ID is required."
+            raise ValueError(msg)
+
     def _validate_vector_store_db_uri(self) -> None:
         """Validate the vector store configuration."""
         for store in self.vector_store.values():
@@ -350,4 +360,5 @@ def _validate_model(self):
         self._validate_multi_output_base_dirs()
         self._validate_update_index_output_base_dir()
         self._validate_vector_store_db_uri()
+        self._validate_user_id()
         return self
diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py
index 935644b025..9f6afc463c 100644
--- a/graphrag/index/operations/embed_text/embed_text.py
+++ b/graphrag/index/operations/embed_text/embed_text.py
@@ -42,6 +42,7 @@ async def embed_text(
     embed_column: str,
     strategy: dict,
     embedding_name: str,
+    user_id: str = "",
     id_column: str = "id",
     title_column: str | None = None,
 ):
@@ -66,6 +67,7 @@ async def embed_text(
             vector_store_config=vector_store_workflow_config,
             id_column=id_column,
             title_column=title_column,
+            doc_user_id=user_id,
         )
 
     return await _text_embed_in_memory(
@@ -102,6 +104,7 @@ async def _text_embed_with_vector_store(
     strategy: dict[str, Any],
     vector_store: BaseVectorStore,
     vector_store_config: dict,
+    doc_user_id: str,
     id_column: str = "id",
     title_column: str | None = None,
 ):
@@ -169,6 +172,7 @@ async def _text_embed_with_vector_store(
                 text=doc_text,
                 vector=doc_vector,
                 attributes={"title": doc_title},
+                user_id=doc_user_id,
             )
             documents.append(document)
 
diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py
index 98bc4e6692..c8088d225d 100644
--- a/graphrag/index/workflows/generate_text_embeddings.py
+++ b/graphrag/index/workflows/generate_text_embeddings.py
@@ -56,6 +56,7 @@ async def run_workflow(
         cache=context.cache,
         text_embed_config=text_embed,
         embedded_fields=embedded_fields,
+        user_id=config.user_id,
     )
 
     if config.snapshots.embeddings:
@@ -79,6 +80,7 @@ async def generate_text_embeddings(
     cache: PipelineCache,
     text_embed_config: dict,
     embedded_fields: set[str],
+    user_id: str = "",
 ) -> dict[str, pd.DataFrame]:
     """All the steps to generate all embeddings."""
     embedding_param_map = {
@@ -138,6 +140,7 @@ async def generate_text_embeddings(
             callbacks=callbacks,
             cache=cache,
             text_embed_config=text_embed_config,
+            user_id=user_id,
             **embedding_param_map[field],
         )
     return outputs
@@ -150,6 +153,7 @@ async def _run_and_snapshot_embeddings(
     callbacks: WorkflowCallbacks,
     cache: PipelineCache,
     text_embed_config: dict,
+    user_id: str,
 ) -> pd.DataFrame:
     """All the steps to generate single embedding."""
     data["embedding"] = await embed_text(
@@ -159,6 +163,7 @@ async def _run_and_snapshot_embeddings(
         embed_column=embed_column,
         embedding_name=name,
         strategy=text_embed_config["strategy"],
+        user_id=user_id,
     )
     return data.loc[:, ["id", "embedding"]]
 
diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py
index 89c7f9d499..6faf600611 100644
--- a/graphrag/vector_stores/azure_ai_search.py
+++ b/graphrag/vector_stores/azure_ai_search.py
@@ -120,6 +120,10 @@ def load_documents(
                     name="attributes",
                     type=SearchFieldDataType.String,
                 ),
+                SimpleField(
+                    name="user_id",
+                    type=SearchFieldDataType.String,
+                ),
             ],
             vector_search=vector_search,
         )
@@ -133,7 +137,8 @@ def load_documents(
             "vector": doc.vector,
             "text": doc.text,
             "attributes": json.dumps(doc.attributes),
+            "user_id": doc.user_id,
         }
         for doc in documents
         if doc.vector is not None
     ]
@@ -176,6 +181,7 @@ def similarity_search_by_vector(
                     text=doc.get("text", ""),
                     vector=doc.get("vector", []),
                     attributes=(json.loads(doc.get("attributes", "{}"))),
+                    user_id=doc.get("user_id", ""),
                 ),
                 # Cosine similarity between 0.333 and 1.000
                 # https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
@@ -203,4 +209,5 @@ def search_by_id(self, id: str) -> VectorStoreDocument:
             text=response.get("text", ""),
             vector=response.get("vector", []),
             attributes=(json.loads(response.get("attributes", "{}"))),
+            user_id=response.get("user_id", ""),
         )
diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py
index c4b5e40c42..c5fe55231f 100644
--- a/graphrag/vector_stores/base.py
+++ b/graphrag/vector_stores/base.py
@@ -19,6 +19,9 @@ class VectorStoreDocument:
     id: str | int
     """unique id for the document"""
 
     text: str | None
     vector: list[float] | None
+
+    user_id: str = ""
+    """unique user id; defaults to empty so existing construction sites keep working"""
 