Skip to content

Commit

Permalink
fix: impl missing embeddings method (langchain-ai#10823)
Browse files Browse the repository at this point in the history
FAISS does not implement embeddings method and use embed_query to
embedding texts which is wrong for some embedding models.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
  • Loading branch information
sonald and baskaryan authored Oct 19, 2023
1 parent 2661dc9 commit 77fc2f7
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions libs/langchain/langchain/vectorstores/faiss.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import operator
import os
import pickle
Expand All @@ -15,6 +16,7 @@
Optional,
Sized,
Tuple,
Union,
)

import numpy as np
Expand All @@ -26,6 +28,8 @@
from langchain.schema.vectorstore import VectorStore
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance

logger = logging.getLogger(__name__)


def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
"""
Expand Down Expand Up @@ -82,7 +86,7 @@ class FAISS(VectorStore):

def __init__(
self,
embedding_function: Callable,
embedding_function: Union[Callable, Embeddings],
index: Any,
docstore: Docstore,
index_to_docstore_id: Dict[int, str],
Expand All @@ -91,6 +95,11 @@ def __init__(
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
):
"""Initialize with necessary components."""
if not isinstance(embedding_function, Embeddings):
logger.warning(
"`embedding_function` is expected to be an Embeddings object, support "
"for passing in a function will soon be removed."
)
self.embedding_function = embedding_function
self.index = index
self.docstore = docstore
Expand All @@ -108,6 +117,26 @@ def __init__(
)
)

@property
def embeddings(self) -> Optional[Embeddings]:
return (
self.embedding_function
if isinstance(self.embedding_function, Embeddings)
else None
)

def _embed_documents(self, texts: List[str]) -> List[List[float]]:
if isinstance(self.embedding_function, Embeddings):
return self.embedding_function.embed_documents(texts)
else:
return [self.embedding_function(text) for text in texts]

def _embed_query(self, text: str) -> List[float]:
if isinstance(self.embedding_function, Embeddings):
return self.embedding_function.embed_query(text)
else:
return self.embedding_function(text)

def __add(
self,
texts: Iterable[str],
Expand Down Expand Up @@ -163,7 +192,8 @@ def add_texts(
Returns:
List of ids from adding the texts into the vectorstore.
"""
embeddings = [self.embedding_function(text) for text in texts]
texts = list(texts)
embeddings = self._embed_documents(texts)
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)

def add_embeddings(
Expand Down Expand Up @@ -272,7 +302,7 @@ def similarity_search_with_score(
List of documents most similar to the query text with
L2 distance in float. Lower score represents more similarity.
"""
embedding = self.embedding_function(query)
embedding = self._embed_query(query)
docs = self.similarity_search_with_score_by_vector(
embedding,
k,
Expand Down Expand Up @@ -465,7 +495,7 @@ def max_marginal_relevance_search(
Returns:
List of Documents selected by maximal marginal relevance.
"""
embedding = self.embedding_function(query)
embedding = self._embed_query(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding,
k=k,
Expand Down Expand Up @@ -561,7 +591,7 @@ def __from(
# Default to L2, currently other metric types not initialized.
index = faiss.IndexFlatL2(len(embeddings[0]))
vecstore = cls(
embedding.embed_query,
embedding,
index,
InMemoryDocstore(),
{},
Expand Down Expand Up @@ -696,9 +726,7 @@ def load_local(
# load docstore and index_to_docstore_id
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
docstore, index_to_docstore_id = pickle.load(f)
return cls(
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
)
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)

def serialize_to_bytes(self) -> bytes:
"""Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
Expand All @@ -713,9 +741,7 @@ def deserialize_from_bytes(
) -> FAISS:
"""Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
index, docstore, index_to_docstore_id = pickle.loads(serialized)
return cls(
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
)
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)

def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
Expand Down

0 comments on commit 77fc2f7

Please sign in to comment.