diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py index 96e6111e9e..93813cb391 100644 --- a/graphrag/index/operations/embed_text/embed_text.py +++ b/graphrag/index/operations/embed_text/embed_text.py @@ -12,7 +12,7 @@ from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.config.embeddings import create_collection_name +from graphrag.utils.api import get_collection_name from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument from graphrag.vector_stores.factory import VectorStoreFactory @@ -49,7 +49,7 @@ async def embed_text( vector_store_config = strategy.get("vector_store") if vector_store_config: - collection_name = _get_collection_name(vector_store_config, embedding_name) + collection_name = get_collection_name(vector_store_config, embedding_name) vector_store: BaseVectorStore = _create_vector_store( vector_store_config, collection_name ) @@ -197,15 +197,6 @@ def _create_vector_store( return vector_store -def _get_collection_name(vector_store_config: dict, embedding_name: str) -> str: - container_name = vector_store_config.get("container_name", "default") - collection_name = create_collection_name(container_name, embedding_name) - - msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}" - logger.info(msg) - return collection_name - - def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy: """Load strategy method definition.""" match strategy: diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py index 484a279b1f..e5aa1a94aa 100644 --- a/graphrag/utils/api.py +++ b/graphrag/utils/api.py @@ -20,6 +20,8 @@ VectorStoreSearchResult, ) from graphrag.vector_stores.factory import VectorStoreFactory +import logging +logger = logging.getLogger(__name__) class MultiVectorStore(BaseVectorStore): @@ -103,9 +105,7 @@ def get_embedding_store( index_names = [] for index, store in config_args.items(): vector_store_type = store["type"] - collection_name = create_collection_name( - store.get("container_name", "default"), embedding_name - ) + collection_name = get_collection_name(store, embedding_name) embedding_store = VectorStoreFactory().create_vector_store( vector_store_type=vector_store_type, kwargs={**store, "collection_name": collection_name}, @@ -119,6 +119,21 @@ def get_embedding_store( return MultiVectorStore(embedding_stores, index_names) +def get_collection_name(vector_store_config: dict, embedding_name: str) -> str: + collection_name = vector_store_config.get("collection_name") + if collection_name: + msg = f"using vector store {vector_store_config.get('type')} with user provided collection_name {collection_name} for embedding {embedding_name}" + logger.info(msg) + return collection_name + + container_name = vector_store_config.get("container_name", "default") + collection_name = create_collection_name(container_name, embedding_name) + + msg = f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {collection_name}" + logger.info(msg) + return collection_name + + def reformat_context_data(context_data: dict) -> dict: """ Reformats context_data for all query responses.