Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/config/yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ Where to put all vectors for the system. Configured for lancedb by default. This

- `type` **lancedb|azure_ai_search|cosmosdb** - Type of vector store. Default=`lancedb`
- `db_uri` **str** (only for lancedb) - The database uri. Default=`storage.base_dir/lancedb`
- `url` **str** (only for AI Search) - AI Search endpoint
- `url` **str** (only for AI Search or CosmosDB) - AI Search/Cosmos DB endpoint
- `api_key` **str** (optional - only for AI Search) - The AI Search api key to use.
- `audience` **str** (only for AI Search) - Audience for managed identity token if managed identity authentication is used.
- `container_name` **str** - The name of a vector container. This stores all indexes (tables) for a given dataset ingest. Default=`default`
Expand Down
11 changes: 11 additions & 0 deletions graphrag/config/models/graph_rag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,16 @@ def _validate_reporting_base_dir(self) -> None:
)
"""The basic search configuration."""

user_id: str = Field(
description="User ID", default=""
)
"""The user ID."""

def _validate_user_id(self) -> None:
    """Validate that a non-empty user ID is configured.

    Raises
    ------
    ValueError
        If ``user_id`` is empty or unset.

    NOTE(review): the ``user_id`` field defaults to ``""``, so any config
    that does not explicitly set it will now fail model validation —
    confirm this breaking change is intended for existing configs.
    """
    if not self.user_id:
        msg = "User ID is required."
        raise ValueError(msg)

def _validate_vector_store_db_uri(self) -> None:
"""Validate the vector store configuration."""
for store in self.vector_store.values():
Expand Down Expand Up @@ -350,4 +360,5 @@ def _validate_model(self):
self._validate_multi_output_base_dirs()
self._validate_update_index_output_base_dir()
self._validate_vector_store_db_uri()
self._validate_user_id()
return self
4 changes: 4 additions & 0 deletions graphrag/index/operations/embed_text/embed_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ async def embed_text(
embed_column: str,
strategy: dict,
embedding_name: str,
user_id: str,
id_column: str = "id",
title_column: str | None = None,
):
Expand All @@ -66,6 +67,7 @@ async def embed_text(
vector_store_config=vector_store_workflow_config,
id_column=id_column,
title_column=title_column,
doc_user_id=user_id
)

return await _text_embed_in_memory(
Expand Down Expand Up @@ -102,6 +104,7 @@ async def _text_embed_with_vector_store(
strategy: dict[str, Any],
vector_store: BaseVectorStore,
vector_store_config: dict,
doc_user_id: str,
id_column: str = "id",
title_column: str | None = None,
):
Expand Down Expand Up @@ -169,6 +172,7 @@ async def _text_embed_with_vector_store(
text=doc_text,
vector=doc_vector,
attributes={"title": doc_title},
user_id=doc_user_id,
)
documents.append(document)

Expand Down
5 changes: 5 additions & 0 deletions graphrag/index/workflows/generate_text_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def run_workflow(
cache=context.cache,
text_embed_config=text_embed,
embedded_fields=embedded_fields,
user_id=config.user_id
)

if config.snapshots.embeddings:
Expand All @@ -79,6 +80,7 @@ async def generate_text_embeddings(
cache: PipelineCache,
text_embed_config: dict,
embedded_fields: set[str],
user_id: str
) -> dict[str, pd.DataFrame]:
"""All the steps to generate all embeddings."""
embedding_param_map = {
Expand Down Expand Up @@ -138,6 +140,7 @@ async def generate_text_embeddings(
callbacks=callbacks,
cache=cache,
text_embed_config=text_embed_config,
user_id=user_id,
**embedding_param_map[field],
)
return outputs
Expand All @@ -150,6 +153,7 @@ async def _run_and_snapshot_embeddings(
callbacks: WorkflowCallbacks,
cache: PipelineCache,
text_embed_config: dict,
user_id: str,
) -> pd.DataFrame:
"""All the steps to generate single embedding."""
data["embedding"] = await embed_text(
Expand All @@ -159,6 +163,7 @@ async def _run_and_snapshot_embeddings(
embed_column=embed_column,
embedding_name=name,
strategy=text_embed_config["strategy"],
user_id=user_id
)

return data.loc[:, ["id", "embedding"]]
13 changes: 10 additions & 3 deletions graphrag/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def connect(self, **kwargs: Any) -> Any:
raise ValueError(not_supported_error)

def load_documents(
self, documents: list[VectorStoreDocument], overwrite: bool = True
self, documents: list[VectorStoreDocument], overwrite: bool = True
) -> None:
"""Load documents into an Azure AI Search index."""
if overwrite:
Expand Down Expand Up @@ -120,6 +120,10 @@ def load_documents(
name="attributes",
type=SearchFieldDataType.String,
),
SimpleField(
name="user_id",
type=SearchFieldDataType.String,
),
],
vector_search=vector_search,
)
Expand All @@ -133,6 +137,7 @@ def load_documents(
"vector": doc.vector,
"text": doc.text,
"attributes": json.dumps(doc.attributes),
"user_id": doc.user_id
}
for doc in documents
if doc.vector is not None
Expand All @@ -158,7 +163,7 @@ def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
return self.query_filter

def similarity_search_by_vector(
self, query_embedding: list[float], k: int = 10, **kwargs: Any
self, query_embedding: list[float], k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a vector-based similarity search."""
vectorized_query = VectorizedQuery(
Expand All @@ -176,6 +181,7 @@ def similarity_search_by_vector(
text=doc.get("text", ""),
vector=doc.get("vector", []),
attributes=(json.loads(doc.get("attributes", "{}"))),
user_id=doc.get("user_id", ""),
),
# Cosine similarity between 0.333 and 1.000
# https://learn.microsoft.com/en-us/azure/search/hybrid-search-ranking#scores-in-a-hybrid-search-results
Expand All @@ -185,7 +191,7 @@ def similarity_search_by_vector(
]

def similarity_search_by_text(
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
) -> list[VectorStoreSearchResult]:
"""Perform a text-based similarity search."""
query_embedding = text_embedder(text)
Expand All @@ -203,4 +209,5 @@ def search_by_id(self, id: str) -> VectorStoreDocument:
text=response.get("text", ""),
vector=response.get("vector", []),
attributes=(json.loads(response.get("attributes", "{}"))),
user_id=response.get("user_id", ""),
)
3 changes: 3 additions & 0 deletions graphrag/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class VectorStoreDocument:
id: str | int
"""unique id for the document"""

user_id: str
"""unique user id"""

text: str | None
vector: list[float] | None

Expand Down