diff --git a/README.md b/README.md
index 918433d510..e2529cac1c 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,7 @@ Here is a list of the various API providers and available distributions that can
 | NVIDIA NIM | Hosted and Single Node | | ✅ | | | |
 | Chroma | Single Node | | | ✅ | | |
 | PG Vector | Single Node | | | ✅ | | |
+| MongoDB Atlas | Hosted | | | ✅ | | |
 | PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | |
 | vLLM | Hosted and Single Node | | ✅ | | | |
 | OpenAI | Hosted | | ✅ | | | |
diff --git a/docs/source/providers/vector_io/mongodb.md b/docs/source/providers/vector_io/mongodb.md
new file mode 100644
index 0000000000..67e4adec00
--- /dev/null
+++ b/docs/source/providers/vector_io/mongodb.md
@@ -0,0 +1,35 @@
+---
+orphan: true
+---
+# MongoDB Atlas
+
+[MongoDB Atlas](https://www.mongodb.com/atlas) is a managed cloud database service that can be used as a vector store provider for Llama Stack. Through its Atlas Vector Search feature, it lets you store and query embeddings alongside the rest of the data in your MongoDB database.
+
+## Features
+MongoDB Atlas Vector Search supports:
+- Storage of embeddings and their metadata
+- Vector search with multiple similarity functions (cosine, euclidean, dot product)
+- Hybrid search (combining vector and keyword search)
+- Metadata filtering
+- Scalable vector indexing
+- Managed cloud infrastructure
+
+## Usage
+
+To use MongoDB Atlas in your Llama Stack project, follow these steps:
+
+1. Create a MongoDB Atlas account and cluster.
+2. Configure your Atlas cluster to enable Vector Search.
+3. Configure your Llama Stack project to use MongoDB Atlas.
+4. Start storing and querying vectors.
+
+## Installation
+
+You can install the required Python packages using pip:
+
+```bash
+pip install pymongo certifi
+```
+
+## Documentation
+See the [MongoDB Atlas Vector Search documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/) for more details about vector search capabilities in MongoDB Atlas.
diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py
index 93031763d1..ee2d4c706d 100644
--- a/llama_stack/providers/registry/vector_io.py
+++ b/llama_stack/providers/registry/vector_io.py
@@ -110,6 +110,16 @@ def available_providers() -> List[ProviderSpec]:
             ),
             api_dependencies=[Api.inference],
         ),
+        remote_provider_spec(
+            Api.vector_io,
+            AdapterSpec(
+                adapter_type="mongodb",
+                pip_packages=["pymongo", "certifi"],
+                module="llama_stack.providers.remote.vector_io.mongodb",
+                config_class="llama_stack.providers.remote.vector_io.mongodb.MongoDBVectorIOConfig",
+            ),
+            api_dependencies=[Api.inference],
+        ),
         remote_provider_spec(
             Api.vector_io,
             AdapterSpec(
diff --git a/llama_stack/providers/remote/vector_io/mongodb/__init__.py b/llama_stack/providers/remote/vector_io/mongodb/__init__.py
new file mode 100644
index 0000000000..fd46551b18
--- /dev/null
+++ b/llama_stack/providers/remote/vector_io/mongodb/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
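+
+# Provider entry point: builds the MongoDB Atlas vector IO adapter from its
+# config and the stack-resolved inference API dependency.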
+
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
+from .config import MongoDBVectorIOConfig
+
+
+async def get_adapter_impl(config: MongoDBVectorIOConfig, deps: Dict[Api, ProviderSpec]):
+    from .mongodb import MongoDBVectorIOAdapter
+
+    impl = MongoDBVectorIOAdapter(config, deps[Api.inference])
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/remote/vector_io/mongodb/config.py b/llama_stack/providers/remote/vector_io/mongodb/config.py
new file mode 100644
index 0000000000..0ebaad9f6a
--- /dev/null
+++ b/llama_stack/providers/remote/vector_io/mongodb/config.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class MongoDBVectorIOConfig(BaseModel):
+    connection_str: str = Field(..., description="Connection string for the MongoDB Atlas cluster")
+    namespace: str = Field(..., description="Namespace, i.e. db_name.collection_name, of the MongoDB Atlas collection")
+    index_name: Optional[str] = Field("default", description="Name of the vector search index in the MongoDB Atlas collection")
+    filter_fields: Optional[List[str]] = Field(default_factory=list, description="Fields to filter on alongside vector search in the MongoDB Atlas collection")
+    embeddings_key: Optional[str] = Field("embeddings", description="Field name for the embeddings in the MongoDB Atlas collection")
+    text_field: Optional[str] = Field("text", description="Field name for the text in the MongoDB Atlas collection")
+
+    @classmethod
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> Dict[str, Any]:
+        return {
+            "connection_str": "{env.MONGODB_CONNECTION_STR}",
+            "namespace": "{env.MONGODB_NAMESPACE}",
+        }
diff --git a/llama_stack/providers/remote/vector_io/mongodb/mongodb.py b/llama_stack/providers/remote/vector_io/mongodb/mongodb.py
new file mode 100644
index 0000000000..cc172fc706
--- /dev/null
+++ b/llama_stack/providers/remote/vector_io/mongodb/mongodb.py
@@ -0,0 +1,265 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
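+
+# MongoDB Atlas vector IO adapter. MongoDBAtlasIndex manages a single Atlas
+# Vector Search index over one namespace (db_name.collection_name), and
+# MongoDBVectorIOAdapter maps registered vector DBs onto it, delegating
+# embedding generation to the inference API through VectorDBWithIndex.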
+
+import json
+import logging
+from time import sleep
+from typing import Any, Dict, List, Optional
+
+import certifi
+from numpy.typing import NDArray
+from pymongo import MongoClient
+from pymongo.operations import SearchIndexModel, UpdateOne
+
+from llama_stack.apis.inference import InterleavedContent
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.memory.vector_store import (
+    EmbeddingIndex,
+    VectorDBWithIndex,
+)
+
+from .config import MongoDBVectorIOConfig
+
+log = logging.getLogger(__name__)
+CHUNK_ID_KEY = "_chunk_id"
+
+
+class MongoDBAtlasIndex(EmbeddingIndex):
+    def __init__(
+        self,
+        client: MongoClient,
+        namespace: str,
+        embeddings_key: str,
+        embedding_dimension: int,
+        index_name: str,
+        filter_fields: List[str],
+    ):
+        self.client = client
+        self.namespace = namespace
+        self.embeddings_key = embeddings_key
+        self.index_name = index_name
+        self.filter_fields = filter_fields
+        self.embedding_dimension = embedding_dimension
+
+    def _get_index_config(self, collection, index_name):
+        # Return the search index definition if it exists, else None.
+        for idx in collection.list_search_indexes():
+            if idx["name"] == index_name:
+                return idx
+        return None
+
+    def _get_search_index_model(self):
+        index_fields = [
+            {
+                "path": self.embeddings_key,
+                "type": "vector",
+                "numDimensions": self.embedding_dimension,
+                "similarity": "cosine",
+            }
+        ]
+
+        # Expose any configured metadata fields as vector-search filters.
+        for filter_field in self.filter_fields:
+            index_fields.append(
+                {
+                    "path": filter_field,
+                    "type": "filter",
+                }
+            )
+
+        return SearchIndexModel(
+            name=self.index_name,
+            type="vectorSearch",
+            definition={"fields": index_fields},
+        )
+
+    def _check_n_create_index(self):
+        db_name, collection_name = self.namespace.split(".")
+        collection = self.client[db_name][collection_name]
+        index_name = self.index_name
+        idx = self._get_index_config(collection, index_name)
+        if not idx:
+            log.info("Creating search index ...")
+            search_index_model = self._get_search_index_model()
+            collection.create_search_index(search_index_model)
+            # Atlas builds search indexes asynchronously; poll until queryable.
+            while True:
+                idx = self._get_index_config(collection, index_name)
+                if idx and idx["queryable"]:
+                    log.info("Search index created successfully.")
+                    break
+                log.info("Waiting for search index to be created ...")
+                sleep(5)
+        else:
+            log.info("Search index already exists.")
+
+    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
+        assert len(chunks) == len(embeddings), (
+            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
+        )
+
+        # Upsert each chunk together with its embedding, keyed by chunk id.
+        operations = []
+        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
+            chunk_id = f"{chunk.metadata['document_id']}:chunk-{i}"
+            operations.append(
+                UpdateOne(
+                    {CHUNK_ID_KEY: chunk_id},
+                    {
+                        "$set": {
+                            CHUNK_ID_KEY: chunk_id,
+                            "chunk_content": json.loads(chunk.model_dump_json()),
+                            self.embeddings_key: embedding.tolist(),
+                        }
+                    },
+                    upsert=True,
+                )
+            )
+
+        # Perform the bulk operations
+        db_name, collection_name = self.namespace.split(".")
+        collection = self.client[db_name][collection_name]
+        collection.bulk_write(operations)
+        log.info("Added %d chunks to the collection", len(chunks))
+        # Make sure the vector search index exists before queries arrive.
+        self._check_n_create_index()
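+    # query() runs an Atlas aggregation pipeline:
+    # $vectorSearch (ANN over the embeddings field) -> $addFields (surface the
+    # similarity score) -> $match (apply the score threshold) -> $project.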
+    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+        db_name, collection_name = self.namespace.split(".")
+        collection = self.client[db_name][collection_name]
+
+        # Create the vector search stage. numCandidates controls recall;
+        # Atlas recommends over-requesting relative to the final limit.
+        vs_query = {
+            "$vectorSearch": {
+                "index": self.index_name,
+                "path": self.embeddings_key,
+                "queryVector": embedding.tolist(),
+                "numCandidates": k,
+                "limit": k,
+            }
+        }
+        # Add a field to store the score
+        score_add_field_query = {
+            "$addFields": {
+                "score": {"$meta": "vectorSearchScore"},
+            }
+        }
+        if score_threshold is None:
+            score_threshold = 0.01
+        # Filter the results based on the score threshold
+        filter_query = {
+            "$match": {
+                "score": {"$gt": score_threshold},
+            }
+        }
+        project_query = {
+            "$project": {
+                CHUNK_ID_KEY: 1,
+                "chunk_content": 1,
+                "score": 1,
+                "_id": 0,
+            }
+        }
+
+        pipeline = [vs_query, score_add_field_query, filter_query, project_query]
+        results = collection.aggregate(pipeline)
+
+        # Create the response
+        chunks = []
+        scores = []
+        for result in results:
+            content = result["chunk_content"]
+            chunk = Chunk(
+                metadata=content["metadata"],
+                content=content["content"],
+            )
+            chunks.append(chunk)
+            scores.append(result["score"])
+
+        return QueryChunksResponse(chunks=chunks, scores=scores)
+
+    async def delete(self):
+        # Drop only this index's collection, not the whole database.
+        db_name, collection_name = self.namespace.split(".")
+        self.client[db_name].drop_collection(collection_name)
+
+
+class MongoDBVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+    def __init__(self, config: MongoDBVectorIOConfig, inference_api: Api.inference):
+        self.config = config
+        self.inference_api = inference_api
+        self.client: Optional[MongoClient] = None
+        self.cache: Dict[str, VectorDBWithIndex] = {}
+
+    async def initialize(self) -> None:
+        self.client = MongoClient(
+            self.config.connection_str,
+            tlsCAFile=certifi.where(),
+        )
+
+    async def shutdown(self) -> None:
+        if self.client:
+            self.client.close()
+
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        index = MongoDBAtlasIndex(
+            client=self.client,
+            namespace=self.config.namespace,
+            embeddings_key=self.config.embeddings_key,
+            embedding_dimension=vector_db.embedding_dimension,
+            index_name=self.config.index_name,
+            filter_fields=self.config.filter_fields,
+        )
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=index,
+            inference_api=self.inference_api,
+        )
+
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        self.cache[vector_db_id] = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=MongoDBAtlasIndex(
+                client=self.client,
+                namespace=self.config.namespace,
+                embeddings_key=self.config.embeddings_key,
+                embedding_dimension=vector_db.embedding_dimension,
+                index_name=self.config.index_name,
+                filter_fields=self.config.filter_fields,
+            ),
+            inference_api=self.inference_api,
+        )
+        return self.cache[vector_db_id]
+
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]
+
+    async def insert_chunks(
+        self,
+        vector_db_id: str,
+        chunks: List[Chunk],
+        ttl_seconds: Optional[int] = None,
+    ) -> None:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        if not index:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
+        await index.insert_chunks(chunks)
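+    # query_chunks delegates to VectorDBWithIndex, which embeds the query text
+    # via the inference API and then calls MongoDBAtlasIndex.query.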
+    async def query_chunks(
+        self,
+        vector_db_id: str,
+        query: InterleavedContent,
+        params: Optional[Dict[str, Any]] = None,
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        if not index:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
+        return await index.query_chunks(query, params)
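
For context, a minimal sketch of exercising the provider end to end through the Python client. This is not part of the diff; the base URL, provider id, embedding model, and dimension are illustrative assumptions, and `MONGODB_CONNECTION_STR` / `MONGODB_NAMESPACE` must be set in the run environment as in `sample_run_config`.

```python
# Hedged usage sketch: assumes a running Llama Stack with the "mongodb"
# vector_io provider configured and an embedding model available; the
# model name, dimension, and port below are assumptions, not fixed values.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register a vector DB backed by the MongoDB Atlas provider.
client.vector_dbs.register(
    vector_db_id="my_docs",
    embedding_model="all-MiniLM-L6-v2",  # assumed embedding model
    embedding_dimension=384,
    provider_id="mongodb",
)

# Insert a chunk; metadata.document_id is required because add_chunks
# derives chunk ids from it.
client.vector_io.insert(
    vector_db_id="my_docs",
    chunks=[
        {
            "content": "MongoDB Atlas provides managed vector search.",
            "metadata": {"document_id": "doc-1"},
        }
    ],
)

# Query; the provider drops hits at or below the score threshold.
response = client.vector_io.query(vector_db_id="my_docs", query="what provides vector search?")
for chunk, score in zip(response.chunks, response.scores):
    print(score, chunk.content)
```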