Add support for MongoDB Atlas $vectorSearch vector search #11139

Merged
merged 14 commits on Sep 28, 2023
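At a high level, the change teaches MongoDBAtlasVectorSearch to issue the new Atlas $vectorSearch aggregation stage and to fall back to the older knnBeta $search operator when the cluster rejects it. A rough usage sketch follows; the embedding class, connection string, namespace, and index name are placeholders, not values taken from this PR:

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch

# Placeholder connection details; an Atlas cluster with a vector search index on
# the "embedding" field is assumed to exist already. index_name is passed as a
# keyword argument on the assumption that extra kwargs reach the constructor.
vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
    "mongodb+srv://<user>:<password>@<cluster>.mongodb.net",
    "my_db.my_collection",
    OpenAIEmbeddings(),
    index_name="langchain-demo-index",
)

# On clusters that support it, this now runs a $vectorSearch stage; otherwise the
# store logs the failure once and falls back to the knnBeta $search operator.
docs = vectorstore.similarity_search("How does Atlas vector search work?", k=4)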
libs/langchain/langchain/vectorstores/mongodb_atlas.py (180 changes: 130 additions & 50 deletions)
@@ -15,6 +15,7 @@
)

import numpy as np
from pymongo.errors import OperationFailure

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
@@ -76,6 +77,7 @@ def __init__(
self._index_name = index_name
self._text_key = text_key
self._embedding_key = embedding_key
self._use_vectorsearch = True

@property
def embeddings(self) -> Embeddings:
@@ -89,6 +91,18 @@ def from_connection_string(
embedding: Embeddings,
**kwargs: Any,
) -> MongoDBAtlasVectorSearch:
"""Construct a `MongoDB Atlas Vector Search` vector store
from a MongoDB connection URI.

Args:
connection_string: A valid MongoDB connection URI.
namespace: A valid MongoDB namespace (database and collection).
embedding: The text embedding model to use for the vector store.

Returns:
A new MongoDBAtlasVectorSearch instance.

"""
try:
from pymongo import MongoClient
except ImportError:
@@ -145,29 +159,21 @@ def _insert_texts(self, texts: List[str], metadatas: List[Dict[str, Any]]) -> Li
insert_result = self._collection.insert_many(to_insert) # type: ignore
return insert_result.inserted_ids

def _similarity_search_with_score(
def _similarity_search_query(
self,
embedding: List[float],
k: int = 4,
pre_filter: Optional[dict] = None,
query: Dict[str, Any],
post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
knn_beta = {
"vector": embedding,
"path": self._embedding_key,
"k": k,
}
if pre_filter:
knn_beta["filter"] = pre_filter
pipeline = [
{
"$search": {
"index": self._index_name,
"knnBeta": knn_beta,
}
},
{"$set": {"score": {"$meta": "searchScore"}}},
]
if self._use_vectorsearch:
pipeline = [
query,
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
else:
pipeline = [
query,
{"$set": {"score": {"$meta": "searchScore"}}},
]
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
@@ -178,17 +184,91 @@ def _similarity_search_with_score(
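# (context collapsed above: each aggregation result presumably has its text field
#  popped out via self._text_key and its score read from the "score" field added
#  by the $set stage, leaving the remaining keys as the Document metadata)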
docs.append((Document(page_content=text, metadata=res), score))
return docs

def _similarity_search_query_search(
self,
embedding: List[float],
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
knn_beta = {
"vector": embedding,
"path": self._embedding_key,
"k": k,
}
if pre_filter:
knn_beta["filter"] = pre_filter
query = {
"$search": {
"index": self._index_name,
"knnBeta": knn_beta,
}
}

return self._similarity_search_query(query, post_filter_pipeline)
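For reference, the knnBeta $search stage assembled by this helper looks roughly like the following; the vector values and index name are illustrative, and the filter key appears only when pre_filter is supplied:

example_search_stage = {
    "$search": {
        "index": "langchain-demo-index",  # placeholder index name
        "knnBeta": {
            "vector": [0.12, -0.03, 0.57],  # illustrative query embedding
            "path": "embedding",
            "k": 4,
        },
    }
}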

def _similarity_search_query_vectorsearch(
self,
embedding: List[float],
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
params = {
"queryVector": embedding,
"path": self._embedding_key,
"numCandidates": k * 10,
"limit": k,
"index": self._index_name,
}
if pre_filter:
params["filter"] = pre_filter
query = {"$vectorSearch": params}

return self._similarity_search_query(query, post_filter_pipeline)
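The new helper emits a $vectorSearch stage of roughly this shape instead; numCandidates is derived as k * 10, so the default k=4 asks Atlas to consider 40 candidates (values below are illustrative):

example_vectorsearch_stage = {
    "$vectorSearch": {
        "queryVector": [0.12, -0.03, 0.57],  # illustrative query embedding
        "path": "embedding",
        "numCandidates": 40,  # k * 10 with the default k=4
        "limit": 4,
        "index": "langchain-demo-index",  # placeholder index name
    }
}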

def _similarity_search_with_score(
self,
embedding: List[float],
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
if self._use_vectorsearch:
try:
result = self._similarity_search_query_vectorsearch(
embedding, k, pre_filter, post_filter_pipeline
)
except OperationFailure as e:
# QueryFeatureNotAllowed or unknown pipeline stage $vectorSearch
if e.code == 224 or "$vectorSearch" in str(e):
logger.error(
f"$vectorSearch not supported for this Atlas version. "
f"Attempting to use $search. Original error:\n\t{e}"
)
self._use_vectorsearch = False
result = self._similarity_search_query_search(
embedding, k, pre_filter, post_filter_pipeline
)
else:
raise
else:
result = self._similarity_search_query_search(
embedding, k, pre_filter, post_filter_pipeline
)
return result
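The net effect: the first OperationFailure with code 224 (QueryFeatureNotAllowed) or an unknown-stage message mentioning $vectorSearch flips _use_vectorsearch to False, so subsequent calls on the same instance go straight to the knnBeta path without retrying. A caller-side sketch, reusing the placeholder vectorstore from the example near the top:

# Placeholder query text; post_filter_pipeline is an ordinary aggregation pipeline
# appended after whichever vector search stage ends up running.
for doc, score in vectorstore.similarity_search_with_score(
    "How do I create an Atlas vector index?",
    k=2,
    post_filter_pipeline=[{"$limit": 1}],
):
    print(score, doc.page_content[:80])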

def similarity_search_with_score(
self,
query: str,
*,
k: int = 4,
pre_filter: Optional[dict] = None,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
"""Return MongoDB documents most similar to query, along with scores.
"""Return MongoDB documents most similar to the given query and their scores.

Use the knnBeta Operator available in MongoDB Atlas Search
Uses the knnBeta Operator available in MongoDB Atlas Search.
This feature is in early access and available only for evaluation purposes, to
validate functionality, and to gather feedback from a small closed group of
early access users. It is not recommended for production deployments as we
@@ -197,14 +277,14 @@ def similarity_search_with_score(

Args:
query: Text to look up documents similar to.
k: Optional Number of Documents to return. Defaults to 4.
pre_filter: Optional Dictionary of argument(s) to prefilter on document
fields.
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
following the knnBeta search.
k: (Optional) number of documents to return. Defaults to 4.
pre_filter: (Optional) dictionary of argument(s) to prefilter document
fields on.
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
following the knnBeta vector search.

Returns:
List of Documents most similar to the query and score for each
List of documents most similar to the query and their scores.
"""
embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
@@ -219,29 +299,29 @@ def similarity_search(
self,
query: str,
k: int = 4,
pre_filter: Optional[dict] = None,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return MongoDB documents most similar to query.
"""Return MongoDB documents most similar to the given query.

Use the knnBeta Operator available in MongoDB Atlas Search
Uses the knnBeta Operator available in MongoDB Atlas Search.
This feature is in early access and available only for evaluation purposes, to
validate functionality, and to gather feedback from a small closed group of
early access users. It is not recommended for production deployments as we may
introduce breaking changes.
early access users. It is not recommended for production deployments as we
may introduce breaking changes.
For more: https://www.mongodb.com/docs/atlas/atlas-search/knn-beta

Args:
query: Text to look up documents similar to.
k: Optional Number of Documents to return. Defaults to 4.
pre_filter: Optional Dictionary of argument(s) to prefilter on document
fields.
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
following the knnBeta search.
k: (Optional) number of documents to return. Defaults to 4.
pre_filter: (Optional) dictionary of argument(s) to prefilter document
fields on.
post_filter_pipeline: (Optional) Pipeline of MongoDB aggregation stages
following the knnBeta vector search.

Returns:
List of Documents most similar to the query and score for each
List of documents most similar to the query.
"""
docs_and_scores = self.similarity_search_with_score(
query,
@@ -257,30 +337,30 @@ def max_marginal_relevance_search(
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[dict] = None,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
"""Return documents selected using the maximal marginal relevance.

Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.

Args:
query: Text to look up documents similar to.
k: Optional Number of Documents to return. Defaults to 4.
fetch_k: Optional Number of Documents to fetch before passing to MMR
k: (Optional) number of documents to return. Defaults to 4.
fetch_k: (Optional) number of documents to fetch before passing to MMR
algorithm. Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
pre_filter: Optional Dictionary of argument(s) to prefilter on document
pre_filter: (Optional) dictionary of argument(s) to prefilter on document
fields.
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
following the knnBeta search.
post_filter_pipeline: (Optional) pipeline of MongoDB aggregation stages
following the knnBeta vector search.
Returns:
List of Documents selected by maximal marginal relevance.
List of documents selected by maximal marginal relevance.
"""
query_embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
@@ -303,11 +383,11 @@ def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
metadatas: Optional[List[Dict]] = None,
collection: Optional[Collection[MongoDBDocumentType]] = None,
**kwargs: Any,
) -> MongoDBAtlasVectorSearch:
"""Construct MongoDBAtlasVectorSearch wrapper from raw documents.
"""Construct a `MongoDB Atlas Vector Search` vector store from raw documents.

This is a user-friendly interface that:
1. Embeds documents.
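The rest of the from_texts docstring is collapsed above. A quick sketch of that path, with a placeholder client, collection, and embedding model:

from pymongo import MongoClient

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch

# Placeholder connection details; an existing collection backed by a matching
# vector search index is assumed.
client = MongoClient("mongodb+srv://<user>:<password>@<cluster>.mongodb.net")
collection = client["my_db"]["my_collection"]

store = MongoDBAtlasVectorSearch.from_texts(
    texts=["MongoDB Atlas now exposes a $vectorSearch aggregation stage."],
    embedding=OpenAIEmbeddings(),
    metadatas=[{"source": "release-notes"}],
    collection=collection,
)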
@@ -3,45 +3,35 @@

import os
from time import sleep
from typing import TYPE_CHECKING, Any

import pytest
from pymongo import MongoClient

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch

if TYPE_CHECKING:
from pymongo import MongoClient

INDEX_NAME = "langchain-test-index"
NAMESPACE = "langchain_test_db.langchain_test_collection"
CONNECTION_STRING = os.environ.get("MONGODB_ATLAS_URI")
DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")

# Instantiate as constant instead of pytest fixture to prevent needing to make multiple
# connections.


@pytest.fixture
def collection() -> Any:
test_client = MongoClient(CONNECTION_STRING)
return test_client[DB_NAME][COLLECTION_NAME]
test_client = MongoClient(CONNECTION_STRING)
collection = test_client[DB_NAME][COLLECTION_NAME]


class TestMongoDBAtlasVectorSearch:
@classmethod
def setup_class(cls, collection: Any) -> None:
def setup_class(cls) -> None:
# ensure the test collection is empty
assert collection.count_documents({}) == 0 # type: ignore[index] # noqa: E501

@classmethod
def teardown_class(cls, collection: Any) -> None:
def teardown_class(cls) -> None:
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]

@pytest.fixture(autouse=True)
def setup(self, collection: Any) -> None:
def setup(self) -> None:
# delete all the documents in the collection
collection.delete_many({}) # type: ignore[index]
