diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 0b4e174bc2..5645e548de 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -27,7 +27,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"] + vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant", "remote::mongodb"] python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} fail-fast: false # we want to run all tests regardless of failure @@ -97,6 +97,16 @@ jobs: -p 6333:6333 \ qdrant/qdrant + - name: Setup MongoDB + if: matrix.vector-io-provider == 'remote::mongodb' + run: | + docker run --rm -d --pull always \ + --name mongodb \ + -p 27017:27017 \ + -e MONGO_INITDB_ROOT_USERNAME=llamastack \ + -e MONGO_INITDB_ROOT_PASSWORD=llamastack \ + mongodb/mongodb-atlas-local:latest + - name: Wait for Qdrant to be ready if: matrix.vector-io-provider == 'remote::qdrant' run: | @@ -112,6 +122,21 @@ jobs: docker logs qdrant exit 1 + - name: Wait for MongoDB to be ready + if: matrix.vector-io-provider == 'remote::mongodb' + run: | + echo "Waiting for MongoDB to be ready..." + for i in {1..30}; do + if docker exec mongodb mongosh --quiet --eval "db.adminCommand('ping').ok" > /dev/null 2>&1; then + echo "MongoDB is ready!" 
+ exit 0 + fi + sleep 2 + done + echo "MongoDB failed to start" + docker logs mongodb + exit 1 + - name: Wait for ChromaDB to be ready if: matrix.vector-io-provider == 'remote::chromadb' run: | @@ -166,6 +191,11 @@ jobs: QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }} ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} + ENABLE_MONGODB: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'true' || '' }} + MONGODB_HOST: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'localhost' || '' }} + MONGODB_PORT: ${{ matrix.vector-io-provider == 'remote::mongodb' && '27017' || '' }} + MONGODB_USERNAME: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'llamastack' || '' }} + MONGODB_PASSWORD: ${{ matrix.vector-io-provider == 'remote::mongodb' && 'llamastack' || '' }} run: | uv run --no-sync \ pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ @@ -192,6 +222,11 @@ jobs: run: | docker logs qdrant > qdrant.log + - name: Write MongoDB logs to file + if: ${{ always() && matrix.vector-io-provider == 'remote::mongodb' }} + run: | + docker logs mongodb > mongodb.log + - name: Upload all logs to artifacts if: ${{ always() }} uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 diff --git a/docs/docs/providers/vector_io/remote_mongodb.mdx b/docs/docs/providers/vector_io/remote_mongodb.mdx new file mode 100644 index 0000000000..8564e1a5e4 --- /dev/null +++ b/docs/docs/providers/vector_io/remote_mongodb.mdx @@ -0,0 +1,268 @@ +--- +description: | + [MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It + uses MongoDB Atlas Vector Search to store and query vectors in the cloud. 
+ That means you get enterprise-grade vector search with MongoDB's scalability and reliability. + + ## Features + + - Cloud-native vector search with MongoDB Atlas + - Fully integrated with Llama Stack + - Enterprise-grade security and scalability + - Supports multiple search modes: vector, keyword, and hybrid search + - Built-in metadata filtering and text search capabilities + - Automatic index management + + ## Search Modes + + MongoDB Atlas Vector Search supports three different search modes: + + ### Vector Search + Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors. + + ```python + # Vector search example + search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="What is machine learning?", + search_mode="vector", + max_num_results=5, + ) + ``` + + ### Keyword Search + Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms. + + ```python + # Keyword search example + search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="Python programming language", + search_mode="keyword", + max_num_results=5, + ) + ``` + + ### Hybrid Search + Hybrid search combines both vector and keyword search methods using configurable reranking algorithms. 
+ + ```python + # Hybrid search with RRF ranker (default) + search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ) + + # Hybrid search with weighted ranker + search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "weighted", + "alpha": 0.7, # 70% vector search, 30% keyword search + } + }, + ) + ``` + + ## Usage + + To use MongoDB Atlas in your Llama Stack project, follow these steps: + + 1. Create a MongoDB Atlas cluster with Vector Search enabled + 2. Install the necessary dependencies + 3. Configure your Llama Stack project to use MongoDB + 4. Start storing and querying vectors + + ## Configuration + + ### Environment Variables + Set up the following environment variable for your MongoDB Atlas connection: + + ```bash + export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack" + ``` + + ### Configuration Example + + ```yaml + vector_io: + - provider_id: mongodb_atlas + provider_type: remote::mongodb + config: + connection_string: "${env.MONGODB_CONNECTION_STRING}" + database_name: "llama_stack" + index_name: "vector_index" + similarity_metric: "cosine" + ``` + + ## Installation + + You can install the MongoDB Python driver using pip: + + ```bash + pip install pymongo + ``` + + ## Documentation + + See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search. + + For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/). 
+sidebar_label: Remote - Mongodb +title: remote::mongodb +--- + +# remote::mongodb + +## Description + + +[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It +uses MongoDB Atlas Vector Search to store and query vectors in the cloud. +That means you get enterprise-grade vector search with MongoDB's scalability and reliability. + +## Features + +- Cloud-native vector search with MongoDB Atlas +- Fully integrated with Llama Stack +- Enterprise-grade security and scalability +- Supports multiple search modes: vector, keyword, and hybrid search +- Built-in metadata filtering and text search capabilities +- Automatic index management + +## Search Modes + +MongoDB Atlas Vector Search supports three different search modes: + +### Vector Search +Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors. + +```python +# Vector search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="What is machine learning?", + search_mode="vector", + max_num_results=5, +) +``` + +### Keyword Search +Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms. + +```python +# Keyword search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="Python programming language", + search_mode="keyword", + max_num_results=5, +) +``` + +### Hybrid Search +Hybrid search combines both vector and keyword search methods using configurable reranking algorithms. 
+ +```python +# Hybrid search with RRF ranker (default) +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, +) + +# Hybrid search with weighted ranker +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "weighted", + "alpha": 0.7, # 70% vector search, 30% keyword search + } + }, +) +``` + +## Usage + +To use MongoDB Atlas in your Llama Stack project, follow these steps: + +1. Create a MongoDB Atlas cluster with Vector Search enabled +2. Install the necessary dependencies +3. Configure your Llama Stack project to use MongoDB +4. Start storing and querying vectors + +## Configuration + +### Environment Variables +Set up the following environment variable for your MongoDB Atlas connection: + +```bash +export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack" +``` + +### Configuration Example + +```yaml +vector_io: + - provider_id: mongodb_atlas + provider_type: remote::mongodb + config: + connection_string: "${env.MONGODB_CONNECTION_STRING}" + database_name: "llama_stack" + index_name: "vector_index" + similarity_metric: "cosine" +``` + +## Installation + +You can install the MongoDB Python driver using pip: + +```bash +pip install pymongo +``` + +## Documentation + +See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search. + +For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/). 
+ + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `connection_string` | `str \| None` | No | | MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/) | +| `database_name` | `` | No | llama_stack | Database name to use for vector collections | +| `index_name` | `` | No | vector_index | Name of the vector search index | +| `path_field` | `` | No | embedding | Field name for storing embeddings | +| `similarity_metric` | `` | No | cosine | Similarity metric: cosine, euclidean, or dotProduct | +| `max_pool_size` | `` | No | 100 | Maximum connection pool size | +| `timeout_ms` | `` | No | 30000 | Connection timeout in milliseconds | +| `persistence` | `llama_stack.core.storage.datatypes.KVStoreReference \| None` | No | | Config for KV store backend for metadata storage | + +## Sample Configuration + +```yaml +connection_string: ${env.MONGODB_CONNECTION_STRING:=} +database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack} +index_name: ${env.MONGODB_INDEX_NAME:=vector_index} +path_field: ${env.MONGODB_PATH_FIELD:=embedding} +similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine} +max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100} +timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000} +persistence: + namespace: vector_io::mongodb_atlas + backend: kv_default +``` diff --git a/llama_stack/providers/remote/vector_io/mongodb/__init__.py b/llama_stack/providers/remote/vector_io/mongodb/__init__.py new file mode 100644 index 0000000000..d209fa3e28 --- /dev/null +++ b/llama_stack/providers/remote/vector_io/mongodb/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: dict[Api, ProviderSpec]):
    """Construct the MongoDB vector-IO adapter and initialize it before returning.

    Args:
        config: Provider configuration forwarded to the adapter unchanged.
        deps: Resolved API dependencies. ``Api.inference`` must be present
            (a missing entry raises ``KeyError``); ``Api.files`` and
            ``Api.models`` are optional and are passed as ``None`` when absent.

    Returns:
        The adapter instance, after ``initialize()`` has completed.
    """
    # Deferred import: the adapter module is only loaded when an adapter is requested.
    from .mongodb import MongoDBVectorIOAdapter

    adapter = MongoDBVectorIOAdapter(
        config,
        deps[Api.inference],
        deps.get(Api.files),
        deps.get(Api.models),
    )
    await adapter.initialize()
    return adapter
+ """ + + # MongoDB Atlas connection details + connection_string: str | None = Field( + default=None, + description="MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/)", + ) + database_name: str = Field(default="llama_stack", description="Database name to use for vector collections") + + # Vector search configuration + index_name: str = Field(default="vector_index", description="Name of the vector search index") + path_field: str = Field(default="embedding", description="Field name for storing embeddings") + similarity_metric: str = Field( + default="cosine", + description="Similarity metric: cosine, euclidean, or dotProduct", + ) + + # Connection options + max_pool_size: int = Field(default=100, description="Maximum connection pool size") + timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds") + + # KV store configuration + kvstore: KVStoreConfig = Field(description="Config for KV store backend for metadata storage") + + @classmethod + def sample_run_config( + cls, + __distro_dir__: str, + connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}", + database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}", + **kwargs: Any, + ) -> dict[str, Any]: + return { + "connection_string": connection_string, + "database_name": database_name, + "index_name": "${env.MONGODB_INDEX_NAME:=vector_index}", + "path_field": "${env.MONGODB_PATH_FIELD:=embedding}", + "similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}", + "max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}", + "timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="mongodb_registry.db", + ), + } diff --git a/llama_stack/providers/remote/vector_io/mongodb/mongodb.py b/llama_stack/providers/remote/vector_io/mongodb/mongodb.py new file mode 100644 index 0000000000..291f20237f --- /dev/null +++ 
b/llama_stack/providers/remote/vector_io/mongodb/mongodb.py @@ -0,0 +1,606 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import heapq +import time +from typing import Any + +from numpy.typing import NDArray +from pymongo import MongoClient +from pymongo.collection import Collection +from pymongo.database import Database +from pymongo.operations import SearchIndexModel +from pymongo.server_api import ServerApi + +from llama_stack.apis.common.errors import VectorStoreNotFoundError +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import ( + HealthResponse, + HealthStatus, + VectorDBsProtocolPrivate, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.memory.openai_vector_store_mixin import ( + OpenAIVectorStoreMixin, +) +from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, + EmbeddingIndex, + VectorDBWithIndex, +) +from llama_stack.providers.utils.vector_io.vector_utils import ( + WeightedInMemoryAggregator, + sanitize_collection_name, +) + +from .config import MongoDBVectorIOConfig + +logger = get_logger(name=__name__, category="vector_io::mongodb") + +VERSION = "v1" +VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::" +OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = 
f"openai_vector_stores_files_contents:mongodb:{VERSION}::" + + +class MongoDBIndex(EmbeddingIndex): + """MongoDB Atlas Vector Search index implementation optimized for RAG.""" + + def __init__( + self, + vector_db: VectorDB, + collection: Collection, + config: MongoDBVectorIOConfig, + ): + self.vector_db = vector_db + self.collection = collection + self.config = config + self.dimension = vector_db.embedding_dimension + + async def initialize(self) -> None: + """Initialize the MongoDB collection and ensure vector search index exists.""" + try: + # Create the collection if it doesn't exist + collection_names = self.collection.database.list_collection_names() + if self.collection.name not in collection_names: + logger.info(f"Creating collection '{self.collection.name}'") + # Create collection by inserting a dummy document + dummy_doc = {"_id": "__dummy__", "dummy": True} + self.collection.insert_one(dummy_doc) + # Remove the dummy document + self.collection.delete_one({"_id": "__dummy__"}) + logger.info(f"Collection '{self.collection.name}' created successfully") + + # Create optimized vector search index for RAG + await self._create_vector_search_index() + + # Create text index for hybrid search + await self._ensure_text_index() + + except Exception as e: + logger.exception( + f"Failed to initialize MongoDB index for vector_db: {self.vector_db.identifier}. " + f"Collection name: {self.collection.name}. Error: {str(e)}" + ) + raise RuntimeError( + f"Failed to initialize MongoDB vector search index. " + f"Vector store '{self.vector_db.identifier}' cannot function without indexes. 
async def _wait_for_index_ready(
    self,
    *,
    max_wait_time: float = 300,
    wait_interval: float = 5,
) -> None:
    """Poll the collection until the vector search index reports ``queryable``.

    Args:
        max_wait_time: Upper bound, in seconds, on the time spent sleeping
            between polls. Defaults to 300 (5 minutes).
        wait_interval: Seconds to sleep between two polls. Must be positive.
            Defaults to 5.

    Returns:
        ``None`` in every case: as soon as the index is queryable, or after the
        wait budget is exhausted (a warning is logged in that case).

    Raises:
        ValueError: If ``wait_interval`` is not positive, since the loop could
            then never make progress toward ``max_wait_time``.
    """
    # Imported locally so this method does not rely on a module-level import.
    import asyncio

    if wait_interval <= 0:
        raise ValueError(f"wait_interval must be positive, got {wait_interval}")

    logger.info("Waiting for vector search index to be ready...")

    elapsed_time = 0
    while elapsed_time < max_wait_time:
        try:
            indices = list(self.collection.list_search_indexes(self.config.index_name))
            if len(indices) and indices[0].get("queryable") is True:
                logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
                return
        except Exception as e:
            # Polling stays best-effort (a failed poll is retried), but the
            # failure is recorded instead of being silently discarded.
            logger.debug(f"Polling vector search index '{self.config.index_name}' failed, will retry: {e}")

        # Yield to the event loop while waiting; a blocking time.sleep() here
        # would stall every other coroutine for the whole wait interval.
        await asyncio.sleep(wait_interval)
        elapsed_time += wait_interval

    logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")
"numCandidates": min(k * 10, 1000), # Cap at 1000 to prevent excessive candidates + "limit": k, + } + }, + { + "$project": { + "_id": 0, + "text": 1, # Primary field for RAG context + "content": 1, # Backward compatibility + "metadata": 1, + "chunk_metadata": 1, + "document": 1, + "score": {"$meta": "vectorSearchScore"}, + } + }, + {"$match": {"score": {"$gte": score_threshold}}}, + ] + + results = list(self.collection.aggregate(pipeline)) + + chunks = [] + scores = [] + for result in results: + score = result.get("score", 0.0) + if score >= score_threshold: + chunk_data = result.get("document", {}) + if chunk_data: + chunks.append(Chunk(**chunk_data)) + scores.append(float(score)) + + logger.debug(f"Vector search for RAG returned {len(chunks)} results") + return QueryChunksResponse(chunks=chunks, scores=scores) + + except Exception as e: + logger.exception(f"Vector search for RAG failed: {e}") + raise RuntimeError(f"Vector search for RAG failed: {e}") from e + + async def query_keyword( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + """Perform text search using MongoDB's text search for RAG context retrieval.""" + try: + # Ensure text index exists + await self._ensure_text_index() + + pipeline = [ + {"$match": {"$text": {"$search": query_string}}}, + { + "$project": { + "_id": 0, + "text": 1, # Primary field for RAG context + "content": 1, # Backward compatibility + "metadata": 1, + "chunk_metadata": 1, + "document": 1, + "score": {"$meta": "textScore"}, + } + }, + {"$match": {"score": {"$gte": score_threshold}}}, + {"$sort": {"score": {"$meta": "textScore"}}}, + {"$limit": k}, + ] + + results = list(self.collection.aggregate(pipeline)) + + chunks = [] + scores = [] + for result in results: + score = result.get("score", 0.0) + if score >= score_threshold: + chunk_data = result.get("document", {}) + if chunk_data: + chunks.append(Chunk(**chunk_data)) + scores.append(float(score)) + + logger.debug(f"Keyword search for 
async def query_hybrid(
    self,
    embedding: NDArray,
    query_string: str,
    k: int,
    score_threshold: float,
    reranker_type: str,
    reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
    """Run vector and keyword search, fuse their scores, and return the best ``k`` chunks.

    Both underlying searches are issued with a score threshold of ``0.0``;
    ``score_threshold`` is applied only to the combined scores, after the
    top ``k`` combined entries have been selected.

    Args:
        embedding: Query embedding forwarded to ``query_vector``.
        query_string: Query text forwarded to ``query_keyword``.
        k: Result limit for each underlying search and for the fused result.
        score_threshold: Minimum combined score a chunk needs to be returned.
        reranker_type: Passed through to ``WeightedInMemoryAggregator.combine_search_results``.
        reranker_params: Passed through to the aggregator; ``None`` becomes ``{}``.

    Returns:
        A ``QueryChunksResponse`` with chunks and their combined scores, in
        descending score order.
    """
    params = {} if reranker_params is None else reranker_params

    by_vector = await self.query_vector(embedding, k, 0.0)
    by_keyword = await self.query_keyword(query_string, k, 0.0)

    # Map chunk_id -> score for each search method.
    vector_scores = {}
    for hit, hit_score in zip(by_vector.chunks, by_vector.scores, strict=False):
        vector_scores[hit.chunk_id] = hit_score

    keyword_scores = {}
    for hit, hit_score in zip(by_keyword.chunks, by_keyword.scores, strict=False):
        keyword_scores[hit.chunk_id] = hit_score

    combined_scores = WeightedInMemoryAggregator.combine_search_results(
        vector_scores, keyword_scores, reranker_type, params
    )

    # Select the k best combined entries first; the threshold is applied afterwards.
    ranked = heapq.nlargest(k, combined_scores.items(), key=lambda pair: pair[1])

    # When a chunk is returned by both searches, the keyword-search instance wins.
    chunk_by_id = {}
    for candidate in by_vector.chunks + by_keyword.chunks:
        chunk_by_id[candidate.chunk_id] = candidate

    chunks = []
    scores = []
    for doc_id, combined in ranked:
        if combined >= score_threshold and doc_id in chunk_by_id:
            chunks.append(chunk_by_id[doc_id])
            scores.append(combined)

    logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
    return QueryChunksResponse(chunks=chunks, scores=scores)
async def _ensure_text_index(self) -> None:
    """Create the text index on ``text`` and ``content`` unless one already exists.

    Best-effort: any failure while listing or creating indexes is logged as a
    warning and not re-raised.
    """
    try:
        text_index_exists = False
        for idx in self.collection.list_indexes():
            if idx.get("textIndexVersion") is None:
                continue  # not a text index
            # MongoDB reports a text index with key {"_fts": "text", "_ftsx": 1};
            # the indexed field names are listed under "weights". Checking only
            # "key" never matches, which re-issues create_index on every call.
            indexed_fields = [*idx.get("weights", {}), *idx.get("key", {})]
            if any(name.startswith(("content", "text")) for name in indexed_fields):
                text_index_exists = True
                break

        if not text_index_exists:
            logger.info("Creating text search index on content fields for RAG")
            # Index both 'text' and 'content' fields for comprehensive text search
            self.collection.create_index([("text", "text"), ("content", "text")])
            logger.info("Text search index created successfully for RAG")

    except Exception as e:
        logger.warning(f"Failed to create text index for RAG: {e}")
async def initialize(self) -> None:
    """Open the KV store and the MongoDB connection, then load persisted state.

    Steps, in order: create the KV store, validate that a connection string
    is configured, create the ``MongoClient``, ping the server, select the
    database, initialize the OpenAI vector stores, and load previously
    registered vector databases from the KV store.

    Raises:
        RuntimeError: If any step fails. The original exception is chained as
            ``__cause__``. A client created before the failure is closed and
            ``self.client`` / ``self.database`` are reset to ``None``.
    """
    logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")

    try:
        # Initialize KV store for metadata
        self.kvstore = await kvstore_impl(self.config.kvstore)

        # Validate connection string
        if not self.config.connection_string:
            raise ValueError(
                "MongoDB connection_string is required but not provided. "
                "Please set MONGODB_CONNECTION_STRING environment variable or provide it in config."
            )

        # Connect to MongoDB
        self.client = MongoClient(
            self.config.connection_string,
            server_api=ServerApi("1"),
            maxPoolSize=self.config.max_pool_size,
            serverSelectionTimeoutMS=self.config.timeout_ms,
            retryWrites=True,
            readPreference="primaryPreferred",
        )

        # Test connection
        self.client.admin.command("ping")
        logger.info("Successfully connected to MongoDB Atlas for RAG")

        # Get database
        self.database = self.client[self.config.database_name]

        # Initialize OpenAI vector stores
        await self.initialize_openai_vector_stores()

        # Load existing vector databases
        await self._load_existing_vector_dbs()

        logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")

    except Exception as e:
        logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
        # Do not leave an open client behind on a half-initialized adapter:
        # close it and reset the handles so later calls see an uninitialized
        # adapter rather than a client that may be unusable.
        if self.client is not None:
            self.client.close()
            self.client = None
            self.database = None
        raise RuntimeError("Failed to initialize MongoDB Atlas Vector IO adapter for RAG") from e
async def unregister_vector_db(self, vector_db_id: str) -> None:
    """Remove a vector database from the cache and from the KV store.

    When ``vector_db_id`` is cached, its index's ``delete()`` is awaited
    before the cache entry is removed. When a KV store is configured, the
    persisted registration key is deleted whether or not the id was cached.
    An unknown id is not an error.
    """
    if vector_db_id in self.cache:
        entry = self.cache[vector_db_id]
        await entry.index.delete()
        # Removed only after the index deletion succeeded.
        self.cache.pop(vector_db_id)

    if self.kvstore:
        registry_key = f"{VECTOR_DBS_PREFIX}{vector_db_id}"
        await self.kvstore.delete(registry_key)

    logger.info(f"Unregistered vector database: {vector_db_id}")
async def _load_existing_vector_dbs(self) -> None:
    """Reload vector databases that were persisted in the KV store.

    Every key under ``VECTOR_DBS_PREFIX`` is read, parsed back into a
    ``VectorDB`` and handed to ``_register_existing_vector_db``. Loading
    is best-effort: a record that cannot be read or parsed is logged and
    skipped, and a failure of the key scan itself is logged without
    raising. Does nothing when no KV store is configured.
    """
    if not self.kvstore:
        return

    try:
        # Range scan over the prefix: the end key is the prefix with its
        # last character incremented, i.e. the first key past the prefix.
        start_key = VECTOR_DBS_PREFIX
        end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1)

        vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key)
    except Exception as e:
        logger.warning(f"Failed to load existing vector databases: {e}")
        return

    for key in vector_db_keys:
        try:
            vector_db_data = await self.kvstore.get(key)
            if not vector_db_data:
                continue

            # Records are written with model_dump_json(), so parse them
            # with the matching pydantic JSON validator.
            vector_db = VectorDB.model_validate_json(vector_db_data)
            # Register the vector database without re-initializing
            await self._register_existing_vector_db(vector_db)
            logger.info(f"Loaded existing RAG-optimized vector database: {vector_db.identifier}")
        except Exception as e:
            # One bad record must not prevent the others from loading.
            logger.warning(f"Failed to load vector database from key {key}: {e}")
${env.MONGODB_INDEX_NAME:=vector_index} + path_field: ${env.MONGODB_PATH_FIELD:=embedding} + similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine} + max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100} + timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000} + persistence: + namespace: vector_io::mongodb_atlas + backend: kv_default - provider_id: ${env.QDRANT_URL:+qdrant} provider_type: remote::qdrant config: diff --git a/src/llama_stack/distributions/starter-gpu/build.yaml b/src/llama_stack/distributions/starter-gpu/build.yaml index b2e2a0c859..2fc44ec9b9 100644 --- a/src/llama_stack/distributions/starter-gpu/build.yaml +++ b/src/llama_stack/distributions/starter-gpu/build.yaml @@ -26,6 +26,7 @@ distribution_spec: - provider_type: inline::milvus - provider_type: remote::chromadb - provider_type: remote::pgvector + - provider_type: remote::mongodb - provider_type: remote::qdrant - provider_type: remote::weaviate files: diff --git a/src/llama_stack/distributions/starter-gpu/run.yaml b/src/llama_stack/distributions/starter-gpu/run.yaml index 807f0d6788..eb8f481fe9 100644 --- a/src/llama_stack/distributions/starter-gpu/run.yaml +++ b/src/llama_stack/distributions/starter-gpu/run.yaml @@ -128,6 +128,19 @@ providers: persistence: namespace: vector_io::pgvector backend: kv_default + - provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas} + provider_type: remote::mongodb + config: + connection_string: ${env.MONGODB_CONNECTION_STRING:=} + database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack} + index_name: ${env.MONGODB_INDEX_NAME:=vector_index} + path_field: ${env.MONGODB_PATH_FIELD:=embedding} + similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine} + max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100} + timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000} + persistence: + namespace: vector_io::mongodb_atlas + backend: kv_default - provider_id: ${env.QDRANT_URL:+qdrant} provider_type: remote::qdrant config: diff --git a/src/llama_stack/distributions/starter/build.yaml 
b/src/llama_stack/distributions/starter/build.yaml index baa80ef3e8..354dbfbb07 100644 --- a/src/llama_stack/distributions/starter/build.yaml +++ b/src/llama_stack/distributions/starter/build.yaml @@ -26,6 +26,7 @@ distribution_spec: - provider_type: inline::milvus - provider_type: remote::chromadb - provider_type: remote::pgvector + - provider_type: remote::mongodb - provider_type: remote::qdrant - provider_type: remote::weaviate files: diff --git a/src/llama_stack/distributions/starter/run.yaml b/src/llama_stack/distributions/starter/run.yaml index eb4652af02..c992ec164a 100644 --- a/src/llama_stack/distributions/starter/run.yaml +++ b/src/llama_stack/distributions/starter/run.yaml @@ -128,6 +128,19 @@ providers: persistence: namespace: vector_io::pgvector backend: kv_default + - provider_id: ${env.MONGODB_CONNECTION_STRING:+mongodb_atlas} + provider_type: remote::mongodb + config: + connection_string: ${env.MONGODB_CONNECTION_STRING:=} + database_name: ${env.MONGODB_DATABASE_NAME:=llama_stack} + index_name: ${env.MONGODB_INDEX_NAME:=vector_index} + path_field: ${env.MONGODB_PATH_FIELD:=embedding} + similarity_metric: ${env.MONGODB_SIMILARITY_METRIC:=cosine} + max_pool_size: ${env.MONGODB_MAX_POOL_SIZE:=100} + timeout_ms: ${env.MONGODB_TIMEOUT_MS:=30000} + persistence: + namespace: vector_io::mongodb_atlas + backend: kv_default - provider_id: ${env.QDRANT_URL:+qdrant} provider_type: remote::qdrant config: diff --git a/src/llama_stack/distributions/starter/starter.py b/src/llama_stack/distributions/starter/starter.py index 49b7a24631..d635607c44 100644 --- a/src/llama_stack/distributions/starter/starter.py +++ b/src/llama_stack/distributions/starter/starter.py @@ -31,11 +31,14 @@ ) from llama_stack.providers.registry.inference import available_providers from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig +from llama_stack.providers.remote.vector_io.mongodb.config import MongoDBVectorIOConfig from 
llama_stack.providers.remote.vector_io.pgvector.config import ( PGVectorVectorIOConfig, ) from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig -from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig +from llama_stack.providers.remote.vector_io.weaviate.config import ( + WeaviateVectorIOConfig, +) from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig @@ -118,6 +121,7 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate: BuildProvider(provider_type="inline::milvus"), BuildProvider(provider_type="remote::chromadb"), BuildProvider(provider_type="remote::pgvector"), + BuildProvider(provider_type="remote::mongodb"), BuildProvider(provider_type="remote::qdrant"), BuildProvider(provider_type="remote::weaviate"), ], @@ -228,6 +232,15 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate: password="${env.PGVECTOR_PASSWORD:=}", ), ), + Provider( + provider_id="${env.MONGODB_CONNECTION_STRING:+mongodb_atlas}", + provider_type="remote::mongodb", + config=MongoDBVectorIOConfig.sample_run_config( + f"~/.llama/distributions/{name}", + connection_string="${env.MONGODB_CONNECTION_STRING:=}", + database_name="${env.MONGODB_DATABASE_NAME:=llama_stack}", + ), + ), Provider( provider_id="${env.QDRANT_URL:+qdrant}", provider_type="remote::qdrant", @@ -327,5 +340,13 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate: "azure", "Azure API Type", ), + "MONGODB_CONNECTION_STRING": ( + "", + "MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/)", + ), + "MONGODB_DATABASE_NAME": ( + "llama_stack", + "MongoDB database name", + ), }, ) diff --git a/src/llama_stack/providers/registry/vector_io.py b/src/llama_stack/providers/registry/vector_io.py index 55b3027514..25f8484adb 100644 --- a/src/llama_stack/providers/registry/vector_io.py +++ b/src/llama_stack/providers/registry/vector_io.py @@ 
-823,6 +823,132 @@ def available_providers() -> list[ProviderSpec]: optional_api_dependencies=[Api.files, Api.models], description=""" Please refer to the remote provider documentation. +""", + ), + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="mongodb", + provider_type="remote::mongodb", + pip_packages=["pymongo>=4.0.0"], + module="llama_stack.providers.remote.vector_io.mongodb", + config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" +[MongoDB Atlas](https://www.mongodb.com/products/platform/atlas-vector-search) is a remote vector database provider for Llama Stack. It +uses MongoDB Atlas Vector Search to store and query vectors in the cloud. +That means you get enterprise-grade vector search with MongoDB's scalability and reliability. + +## Features + +- Cloud-native vector search with MongoDB Atlas +- Fully integrated with Llama Stack +- Enterprise-grade security and scalability +- Supports multiple search modes: vector, keyword, and hybrid search +- Built-in metadata filtering and text search capabilities +- Automatic index management + +## Search Modes + +MongoDB Atlas Vector Search supports three different search modes: + +### Vector Search +Vector search uses MongoDB's `$vectorSearch` aggregation stage to perform semantic similarity search using embedding vectors. + +```python +# Vector search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="What is machine learning?", + search_mode="vector", + max_num_results=5, +) +``` + +### Keyword Search +Keyword search uses MongoDB's text search capabilities with full-text indexes to find chunks containing specific terms. 
+ +```python +# Keyword search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="Python programming language", + search_mode="keyword", + max_num_results=5, +) +``` + +### Hybrid Search +Hybrid search combines both vector and keyword search methods using configurable reranking algorithms. + +```python +# Hybrid search with RRF ranker (default) +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, +) + +# Hybrid search with weighted ranker +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "weighted", + "alpha": 0.7, # 70% vector search, 30% keyword search + } + }, +) +``` + +## Usage + +To use MongoDB Atlas in your Llama Stack project, follow these steps: + +1. Create a MongoDB Atlas cluster with Vector Search enabled +2. Install the necessary dependencies +3. Configure your Llama Stack project to use MongoDB +4. 
Start storing and querying vectors + +## Configuration + +### Environment Variables +Set up the following environment variable for your MongoDB Atlas connection: + +```bash +export MONGODB_CONNECTION_STRING="mongodb+srv://username:password@cluster.mongodb.net/?retryWrites=true&w=majority&appName=llama-stack" +``` + +### Configuration Example + +```yaml +vector_io: + - provider_id: mongodb_atlas + provider_type: remote::mongodb + config: + connection_string: "${env.MONGODB_CONNECTION_STRING}" + database_name: "llama_stack" + index_name: "vector_index" + similarity_metric: "cosine" +``` + +## Installation + +You can install the MongoDB Python driver using pip: + +```bash +pip install pymongo +``` + +## Documentation + +See [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/) for more details about MongoDB Atlas Vector Search. + +For general MongoDB documentation, visit [MongoDB Documentation](https://docs.mongodb.com/). """, ), ] diff --git a/src/llama_stack/providers/remote/vector_io/mongodb/__init__.py b/src/llama_stack/providers/remote/vector_io/mongodb/__init__.py new file mode 100644 index 0000000000..d209fa3e28 --- /dev/null +++ b/src/llama_stack/providers/remote/vector_io/mongodb/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
@json_schema_type
class MongoDBVectorIOConfig(BaseModel):
    """Configuration for MongoDB Atlas Vector Search provider.

    This provider connects to MongoDB Atlas and uses Vector Search for RAG operations.
    """

    # MongoDB Atlas connection details
    connection_string: str | None = Field(
        default=None,
        description="MongoDB Atlas connection string (e.g., mongodb+srv://user:pass@cluster.mongodb.net/)",
    )
    database_name: str = Field(default="llama_stack", description="Database name to use for vector collections")

    # Vector search configuration
    index_name: str = Field(default="vector_index", description="Name of the vector search index")
    path_field: str = Field(default="embedding", description="Field name for storing embeddings")
    # Validated up front: an unrecognized metric would otherwise only be
    # noticed (if at all) when the search index is built with a fallback.
    # "dot_product" is accepted as an alias of "dotProduct".
    similarity_metric: str = Field(
        default="cosine",
        pattern=r"^(cosine|euclidean|dotProduct|dot_product)$",
        description="Similarity metric: cosine, euclidean, or dotProduct",
    )

    # Connection options
    max_pool_size: int = Field(default=100, description="Maximum connection pool size")
    timeout_ms: int = Field(default=30000, description="Connection timeout in milliseconds")

    # KV store configuration
    persistence: KVStoreReference | None = Field(
        description="Config for KV store backend for metadata storage", default=None
    )

    @classmethod
    def sample_run_config(
        cls,
        __distro_dir__: str,
        connection_string: str = "${env.MONGODB_CONNECTION_STRING:=}",
        database_name: str = "${env.MONGODB_DATABASE_NAME:=llama_stack}",
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Return a sample provider config for a distribution's run.yaml.

        Values are ``${env.VAR:=default}`` placeholders rather than resolved
        settings. ``__distro_dir__`` and any extra ``kwargs`` are accepted
        but not used.

        Args:
            __distro_dir__: Distribution directory (unused here).
            connection_string: Placeholder or literal connection string.
            database_name: Placeholder or literal database name.
            **kwargs: Ignored.

        Returns:
            A dict with the connection, index and pool settings plus a
            ``persistence`` entry pointing at the ``kv_default`` backend
            under the ``vector_io::mongodb_atlas`` namespace.
        """
        return {
            "connection_string": connection_string,
            "database_name": database_name,
            "index_name": "${env.MONGODB_INDEX_NAME:=vector_index}",
            "path_field": "${env.MONGODB_PATH_FIELD:=embedding}",
            "similarity_metric": "${env.MONGODB_SIMILARITY_METRIC:=cosine}",
            "max_pool_size": "${env.MONGODB_MAX_POOL_SIZE:=100}",
            "timeout_ms": "${env.MONGODB_TIMEOUT_MS:=30000}",
            "persistence": KVStoreReference(
                backend="kv_default",
                namespace="vector_io::mongodb_atlas",
            ).model_dump(exclude_none=True),
        }
+++ b/src/llama_stack/providers/remote/vector_io/mongodb/mongodb.py @@ -0,0 +1,609 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import heapq +import time +from typing import Any + +from numpy.typing import NDArray +from pymongo import MongoClient +from pymongo.collection import Collection +from pymongo.database import Database +from pymongo.operations import SearchIndexModel +from pymongo.server_api import ServerApi + +from llama_stack.apis.common.errors import VectorStoreNotFoundError +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.apis.vector_stores import VectorStore +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import ( + HealthResponse, + HealthStatus, + VectorStoresProtocolPrivate, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, +) +from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.memory.openai_vector_store_mixin import ( + OpenAIVectorStoreMixin, +) +from llama_stack.providers.utils.memory.vector_store import ( + ChunkForDeletion, + EmbeddingIndex, + VectorStoreWithIndex, +) +from llama_stack.providers.utils.vector_io.vector_utils import ( + WeightedInMemoryAggregator, + sanitize_collection_name, +) + +from .config import MongoDBVectorIOConfig + +logger = get_logger(name=__name__, category="vector_io::mongodb") + +VERSION = "v1" +VECTOR_DBS_PREFIX = f"vector_dbs:mongodb:{VERSION}::" +OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:mongodb:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:mongodb:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = 
class MongoDBIndex(EmbeddingIndex):
    """MongoDB Atlas Vector Search index implementation optimized for RAG.

    Wraps one MongoDB collection. Chunks are stored as documents keyed by
    ``chunk_id``, with the embedding under ``config.path_field``. Vector
    queries use the ``$vectorSearch`` aggregation stage, keyword queries
    use a ``$text`` match, and hybrid queries merge both result sets with
    ``WeightedInMemoryAggregator``.
    """

    def __init__(
        self,
        vector_store: VectorStore,
        collection: Collection,
        config: MongoDBVectorIOConfig,
    ):
        self.vector_store = vector_store
        self.collection = collection
        self.config = config
        self.dimension = vector_store.embedding_dimension

    async def initialize(self) -> None:
        """Ensure the collection, vector search index and text index exist.

        Initialization is best-effort: any failure is logged and swallowed
        so that registration can proceed and indexes can be created
        manually afterwards.
        """
        # Local import: only this method needs the exception type.
        from pymongo.errors import CollectionInvalid

        try:
            # Create the collection if it doesn't exist
            collection_names = self.collection.database.list_collection_names()
            if self.collection.name not in collection_names:
                logger.info(f"Creating collection '{self.collection.name}'")
                try:
                    # Explicit creation instead of inserting and deleting a
                    # placeholder document, which concurrent readers could see.
                    self.collection.database.create_collection(self.collection.name)
                except CollectionInvalid:
                    # Someone else created it between the check and the call.
                    logger.debug(f"Collection '{self.collection.name}' already exists")
                else:
                    logger.info(f"Collection '{self.collection.name}' created successfully")

            # Create optimized vector search index for RAG
            await self._create_vector_search_index()

            # Create text index for hybrid search
            await self._ensure_text_index()

        except Exception as e:
            logger.exception(
                f"Failed to initialize MongoDB index for vector_store: {self.vector_store.identifier}. "
                f"Collection name: {self.collection.name}. Error: {str(e)}"
            )
            # Don't fail completely - just log the error and continue
            logger.warning(
                "Continuing without complete index initialization. "
                "You may need to create indexes manually in MongoDB Atlas dashboard."
            )

    async def _create_vector_search_index(self) -> None:
        """Create the ``vectorSearch`` index named ``config.index_name`` if missing.

        After creation, waits (bounded) for the index to become queryable.
        Failures are logged as warnings and not raised.
        """
        try:
            # Check if vector search index exists
            indexes = list(self.collection.list_search_indexes())
            index_exists = any(idx.get("name") == self.config.index_name for idx in indexes)

            if not index_exists:
                search_index_model = SearchIndexModel(
                    definition={
                        "fields": [
                            {
                                "type": "vector",
                                "numDimensions": self.dimension,
                                "path": self.config.path_field,
                                "similarity": self._convert_similarity_metric(self.config.similarity_metric),
                            }
                        ]
                    },
                    name=self.config.index_name,
                    type="vectorSearch",
                )

                logger.info(
                    f"Creating vector search index '{self.config.index_name}' for RAG on collection '{self.collection.name}'"
                )

                self.collection.create_search_index(model=search_index_model)

                # Index builds are asynchronous on the server side.
                await self._wait_for_index_ready()

                logger.info("Vector search index created and ready for RAG queries")

        except Exception as e:
            logger.warning(f"Failed to create vector search index: {e}")

    def _convert_similarity_metric(self, metric: str) -> str:
        """Map a configured metric name to the name used in the index definition.

        Unknown names fall back to ``"cosine"``; the fallback is logged so a
        misspelled metric does not go unnoticed.
        """
        metric_map = {
            "cosine": "cosine",
            "euclidean": "euclidean",
            "dotProduct": "dotProduct",
            "dot_product": "dotProduct",
        }
        try:
            return metric_map[metric]
        except KeyError:
            logger.warning(f"Unknown similarity metric '{metric}', falling back to 'cosine'")
            return "cosine"

    async def _wait_for_index_ready(self) -> None:
        """Poll until the vector search index reports ``queryable`` or 5 minutes pass.

        Uses ``asyncio.sleep`` between polls so the event loop keeps serving
        other requests while the index builds. On timeout a warning is
        logged and the method returns normally.
        """
        # Local import keeps this block self-contained.
        import asyncio

        logger.info("Waiting for vector search index to be ready...")

        max_wait_time = 300  # 5 minutes max wait
        wait_interval = 5
        elapsed_time = 0

        while elapsed_time < max_wait_time:
            try:
                indices = list(self.collection.list_search_indexes(self.config.index_name))
                if indices and indices[0].get("queryable") is True:
                    logger.info(f"Vector search index '{self.config.index_name}' is ready for querying")
                    return
            except Exception as e:
                # Transient listing errors are retried on the next poll.
                logger.debug(f"Polling vector search index status failed, will retry: {e}")

            await asyncio.sleep(wait_interval)
            elapsed_time += wait_interval

        logger.warning(f"Vector search index may not be fully ready after {max_wait_time}s")

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
        """Upsert ``chunks`` and their ``embeddings`` into the collection.

        Each chunk becomes one document whose ``_id`` is the chunk id, so
        re-adding a chunk replaces the stored document.

        Raises:
            ValueError: If the number of chunks and embeddings differ.
            Exception: Whatever the bulk write raises, after logging.
        """
        # Local import: pymongo is already a dependency of this module.
        from pymongo import ReplaceOne

        if len(chunks) != len(embeddings):
            raise ValueError(f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}")

        documents = []
        for i, chunk in enumerate(chunks):
            content_text = interleaved_content_as_str(chunk.content)
            documents.append(
                {
                    "_id": chunk.chunk_id,
                    "chunk_id": chunk.chunk_id,
                    "text": content_text,  # Key field for RAG context
                    "content": content_text,  # Backward compatibility
                    "metadata": chunk.metadata or {},
                    "chunk_metadata": (chunk.chunk_metadata.model_dump() if chunk.chunk_metadata else {}),
                    self.config.path_field: embeddings[i].tolist(),  # Vector embeddings
                    "document": chunk.model_dump(),  # Full chunk data
                }
            )

        if not documents:
            # bulk_write rejects an empty request list.
            logger.debug("No chunks to add to MongoDB collection")
            return

        try:
            # One round trip for the whole batch instead of one per chunk.
            # The upserts are independent, so ordering is not required.
            self.collection.bulk_write(
                [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in documents],
                ordered=False,
            )

            logger.debug(f"Successfully added {len(chunks)} chunks optimized for RAG to MongoDB collection")
        except Exception as e:
            logger.exception(f"Failed to add chunks to MongoDB collection: {e}")
            raise

    @staticmethod
    def _results_to_response(results: list[dict[str, Any]], score_threshold: float) -> QueryChunksResponse:
        """Turn aggregation results into a response, dropping low scores and empty documents."""
        chunks = []
        scores = []
        for result in results:
            score = result.get("score", 0.0)
            chunk_data = result.get("document", {})
            if score >= score_threshold and chunk_data:
                chunks.append(Chunk(**chunk_data))
                scores.append(float(score))
        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def query_vector(
        self,
        embedding: NDArray,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Return up to ``k`` chunks nearest to ``embedding`` with score >= ``score_threshold``.

        Raises:
            RuntimeError: If the aggregation fails; the cause is chained.
        """
        try:
            # Consider 10x candidates, capped at 1000 to bound the work, but
            # never fewer than k: $vectorSearch rejects numCandidates < limit.
            num_candidates = max(k, min(k * 10, 1000))

            pipeline = [
                {
                    "$vectorSearch": {
                        "index": self.config.index_name,
                        "queryVector": embedding.tolist(),
                        "path": self.config.path_field,
                        "numCandidates": num_candidates,
                        "limit": k,
                    }
                },
                {
                    "$project": {
                        "_id": 0,
                        "text": 1,  # Primary field for RAG context
                        "content": 1,  # Backward compatibility
                        "metadata": 1,
                        "chunk_metadata": 1,
                        "document": 1,
                        "score": {"$meta": "vectorSearchScore"},
                    }
                },
                {"$match": {"score": {"$gte": score_threshold}}},
            ]

            results = list(self.collection.aggregate(pipeline))
            response = self._results_to_response(results, score_threshold)

            logger.debug(f"Vector search for RAG returned {len(response.chunks)} results")
            return response

        except Exception as e:
            logger.exception(f"Vector search for RAG failed: {e}")
            raise RuntimeError(f"Vector search for RAG failed: {e}") from e

    async def query_keyword(
        self,
        query_string: str,
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
        """Return up to ``k`` chunks matching ``query_string`` via ``$text``, best first.

        Raises:
            RuntimeError: If the aggregation fails; the cause is chained.
        """
        try:
            # Ensure text index exists
            await self._ensure_text_index()

            pipeline: list[dict[str, Any]] = [
                {"$match": {"$text": {"$search": query_string}}},
                {
                    "$project": {
                        "_id": 0,
                        "text": 1,  # Primary field for RAG context
                        "content": 1,  # Backward compatibility
                        "metadata": 1,
                        "chunk_metadata": 1,
                        "document": 1,
                        "score": {"$meta": "textScore"},
                    }
                },
                {"$match": {"score": {"$gte": score_threshold}}},
                {"$sort": {"score": {"$meta": "textScore"}}},
                {"$limit": k},
            ]

            results = list(self.collection.aggregate(pipeline))
            response = self._results_to_response(results, score_threshold)

            logger.debug(f"Keyword search for RAG returned {len(response.chunks)} results")
            return response

        except Exception as e:
            logger.exception(f"Keyword search for RAG failed: {e}")
            raise RuntimeError(f"Keyword search for RAG failed: {e}") from e

    async def query_hybrid(
        self,
        embedding: NDArray,
        query_string: str,
        k: int,
        score_threshold: float,
        reranker_type: str,
        reranker_params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Combine vector and keyword results and return the top ``k`` by merged score.

        Both searches run with a threshold of 0.0; ``score_threshold`` is
        applied to the combined scores produced by
        ``WeightedInMemoryAggregator.combine_search_results``.
        """
        if reranker_params is None:
            reranker_params = {}

        # Get results from both search methods
        vector_response = await self.query_vector(embedding, k, 0.0)
        keyword_response = await self.query_keyword(query_string, k, 0.0)

        # Convert responses to score dictionaries
        vector_scores = {
            chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
        }
        keyword_scores = {
            chunk.chunk_id: score
            for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
        }

        # Combine scores using the reranking utility
        combined_scores = WeightedInMemoryAggregator.combine_search_results(
            vector_scores, keyword_scores, reranker_type, reranker_params
        )

        # Top-k by combined score, then apply the caller's threshold.
        top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
        filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]

        # Look chunks up by id across both result sets.
        chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}

        chunks = []
        scores = []
        for doc_id, score in filtered_items:
            if doc_id in chunk_map:
                chunks.append(chunk_map[doc_id])
                scores.append(score)

        logger.debug(f"Hybrid search for RAG returned {len(chunks)} results")
        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
        """Delete the documents whose ``_id`` matches the given chunk ids."""
        chunk_ids = [c.chunk_id for c in chunks_for_deletion]
        try:
            result = self.collection.delete_many({"_id": {"$in": chunk_ids}})
            logger.debug(f"Deleted {result.deleted_count} chunks from MongoDB collection")
        except Exception as e:
            logger.exception(f"Failed to delete chunks: {e}")
            raise

    async def delete(self) -> None:
        """Drop the entire collection."""
        try:
            self.collection.drop()
            logger.debug(f"Dropped MongoDB collection: {self.collection.name}")
        except Exception as e:
            logger.exception(f"Failed to drop collection: {e}")
            raise

    async def _ensure_text_index(self) -> None:
        """Create the text index on ``text`` and ``content`` if the collection has none.

        Failures are logged as warnings and not raised.
        """
        try:
            # A text index reports its key as {_fts: "text", _ftsx: 1}, not
            # the indexed field names, so detect it via textIndexVersion.
            text_index_exists = any(
                idx.get("textIndexVersion") is not None for idx in self.collection.list_indexes()
            )

            if not text_index_exists:
                logger.info("Creating text search index on content fields for RAG")
                # Index both 'text' and 'content' fields for comprehensive text search
                self.collection.create_index([("text", "text"), ("content", "text")])
                logger.info("Text search index created successfully for RAG")

        except Exception as e:
            logger.warning(f"Failed to create text index for RAG: {e}")
async def initialize(self) -> None:
    """Open the KV store and, when credentials are configured, connect to MongoDB.

    Without a ``connection_string`` the method logs a warning and returns,
    leaving ``self.client`` and ``self.database`` unset. Otherwise it
    creates the client, pings the server, selects the database,
    initializes the OpenAI vector stores and reloads the vector stores
    recorded in the KV store.

    Raises:
        RuntimeError: If any step fails. The original exception is chained
            as the cause, and any client that was already created is
            closed so its connection pool is not leaked.
    """
    logger.info("Initializing MongoDB Atlas Vector IO adapter for RAG")

    try:
        # Initialize KV store for metadata
        if self.config.persistence:
            self.kvstore = await kvstore_impl(self.config.persistence)

        # Skip MongoDB connection if no connection string provided
        # This allows other providers to work without MongoDB credentials
        if not self.config.connection_string:
            logger.warning(
                "MongoDB connection_string not provided. "
                "MongoDB vector store will not be available until credentials are configured."
            )
            return

        # Connect to MongoDB with optimized settings for RAG
        self.client = MongoClient(
            self.config.connection_string,
            server_api=ServerApi("1"),
            maxPoolSize=self.config.max_pool_size,
            serverSelectionTimeoutMS=self.config.timeout_ms,
            # Additional settings for RAG performance
            retryWrites=True,
            readPreference="primaryPreferred",
        )

        # Test connection
        self.client.admin.command("ping")
        logger.info("Successfully connected to MongoDB Atlas for RAG")

        # Get database
        self.database = self.client[self.config.database_name]

        # Initialize OpenAI vector stores
        await self.initialize_openai_vector_stores()

        # Load existing vector databases
        await self._load_existing_vector_dbs()

        logger.info("MongoDB Atlas Vector IO adapter for RAG initialized successfully")

    except Exception as e:
        logger.exception("Failed to initialize MongoDB Atlas Vector IO adapter for RAG")
        # Do not leave a half-initialized adapter behind: release the
        # connection pool and reset the handles so health() and shutdown()
        # see a consistent "not initialized" state.
        if self.client is not None:
            self.client.close()
            self.client = None
        self.database = None
        raise RuntimeError("Failed to initialize MongoDB Atlas Vector IO adapter for RAG") from e
-> HealthResponse: + """Perform health check on MongoDB connection.""" + try: + if self.client: + self.client.admin.command("ping") + return HealthResponse(status=HealthStatus.OK) + else: + return HealthResponse(status=HealthStatus.ERROR, message="MongoDB client not initialized") + except Exception as e: + return HealthResponse( + status=HealthStatus.ERROR, + message=f"MongoDB RAG health check failed: {str(e)}", + ) + + async def register_vector_store(self, vector_store: VectorStore) -> None: + """Register a new vector store optimized for RAG.""" + if self.database is None: + raise RuntimeError("MongoDB database not initialized") + + # Create collection name from vector store identifier + collection_name = sanitize_collection_name(vector_store.identifier) + collection = self.database[collection_name] + + # Create and initialize MongoDB index optimized for RAG + mongodb_index = MongoDBIndex(vector_store, collection, self.config) + await mongodb_index.initialize() + + # Create vector store with index wrapper + vector_store_with_index = VectorStoreWithIndex( + vector_store=vector_store, + index=mongodb_index, + inference_api=self.inference_api, + ) + + # Cache the vector store + self.cache[vector_store.identifier] = vector_store_with_index + + # Save vector store info to KVStore for persistence + if self.kvstore: + await self.kvstore.set( + f"{VECTOR_DBS_PREFIX}{vector_store.identifier}", + vector_store.model_dump_json(), + ) + + logger.info(f"Registered vector store for RAG: {vector_store.identifier}") + + async def unregister_vector_store(self, vector_store_id: str) -> None: + """Unregister a vector store.""" + if vector_store_id in self.cache: + await self.cache[vector_store_id].index.delete() + del self.cache[vector_store_id] + + # Clean up from KV store + if self.kvstore: + await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_store_id}") + + logger.info(f"Unregistered vector store: {vector_store_id}") + + async def insert_chunks( + self, + vector_db_id: str, + 
chunks: list[Chunk], + ttl_seconds: int | None = None, + ) -> None: + """Insert chunks into the vector database optimized for RAG.""" + vector_db_with_index = await self._get_vector_db_index(vector_db_id) + await vector_db_with_index.insert_chunks(chunks) + + async def query_chunks( + self, + vector_db_id: str, + query: InterleavedContent, + params: dict[str, Any] | None = None, + ) -> QueryChunksResponse: + """Query chunks from the vector database optimized for RAG context retrieval.""" + vector_db_with_index = await self._get_vector_db_index(vector_db_id) + return await vector_db_with_index.query_chunks(query, params) + + async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + """Delete chunks from the vector database.""" + vector_db_with_index = await self._get_vector_db_index(store_id) + await vector_db_with_index.index.delete_chunks(chunks_for_deletion) + + async def _get_vector_db_index(self, vector_db_id: str) -> VectorStoreWithIndex: + """Get vector store index from cache.""" + if vector_db_id in self.cache: + return self.cache[vector_db_id] + + raise VectorStoreNotFoundError(vector_db_id) + + async def _load_existing_vector_dbs(self) -> None: + """Load existing vector databases from KVStore.""" + if not self.kvstore: + return + + try: + # Use keys_in_range to get all vector database keys from KVStore + # This searches for keys with the prefix by using range scan + start_key = VECTOR_DBS_PREFIX + # Create an end key by incrementing the last character + end_key = VECTOR_DBS_PREFIX[:-1] + chr(ord(VECTOR_DBS_PREFIX[-1]) + 1) + + vector_db_keys = await self.kvstore.keys_in_range(start_key, end_key) + + for key in vector_db_keys: + try: + vector_store_data = await self.kvstore.get(key) + if vector_store_data: + import json + + vector_store = VectorStore(**json.loads(vector_store_data)) + # Register the vector store without re-initializing + await self._register_existing_vector_store(vector_store) + 
logger.info(f"Loaded existing RAG-optimized vector store: {vector_store.identifier}") + except Exception as e: + logger.warning(f"Failed to load vector store from key {key}: {e}") + continue + + except Exception as e: + logger.warning(f"Failed to load existing vector stores: {e}") + + async def _register_existing_vector_store(self, vector_store: VectorStore) -> None: + """Register an existing vector store without re-initialization.""" + if self.database is None: + raise RuntimeError("MongoDB database not initialized") + + # Create collection name from vector store identifier + collection_name = sanitize_collection_name(vector_store.identifier) + collection = self.database[collection_name] + + # Create MongoDB index without initialization (collection already exists) + mongodb_index = MongoDBIndex(vector_store, collection, self.config) + + # Create vector store with index wrapper + vector_store_with_index = VectorStoreWithIndex( + vector_store=vector_store, + index=mongodb_index, + inference_api=self.inference_api, + ) + + # Cache the vector store + self.cache[vector_store.identifier] = vector_store_with_index diff --git a/tests/unit/providers/vector_io/__init__.py b/tests/unit/providers/vector_io/__init__.py new file mode 100644 index 0000000000..756f351d88 --- /dev/null +++ b/tests/unit/providers/vector_io/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree.