Skip to content

Commit

Permalink
community: FAISS vectorstore - consistent Document id field (#28728)
Browse files Browse the repository at this point in the history
make sure id field of Documents in `FAISS` docstore have the same id as
values in `index_to_docstore_id`, implement `get_by_ids` method
  • Loading branch information
nhols authored Dec 15, 2024
1 parent a0534ae commit a3851cb
Show file tree
Hide file tree
Showing 2 changed files with 332 additions and 133 deletions.
16 changes: 11 additions & 5 deletions libs/community/langchain_community/vectorstores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Iterable,
List,
Optional,
Sequence,
Sized,
Tuple,
Union,
Expand Down Expand Up @@ -284,33 +285,34 @@ def __add(
ids: Optional[List[str]] = None,
) -> List[str]:
faiss = dependable_faiss_import()

if not isinstance(self.docstore, AddableMixin):
raise ValueError(
"If trying to add texts, the underlying docstore should support "
f"adding items, which {self.docstore} does not"
)

_len_check_if_sized(texts, metadatas, "texts", "metadatas")

ids = ids or [str(uuid.uuid4()) for _ in texts]
_len_check_if_sized(texts, ids, "texts", "ids")

_metadatas = metadatas or ({} for _ in texts)
documents = [
Document(page_content=t, metadata=m) for t, m in zip(texts, _metadatas)
Document(id=id_, page_content=t, metadata=m)
for id_, t, m in zip(ids, texts, _metadatas)
]

_len_check_if_sized(documents, embeddings, "documents", "embeddings")
_len_check_if_sized(documents, ids, "documents", "ids")

if ids and len(ids) != len(set(ids)):
raise ValueError("Duplicate ids found in the ids list.")

# Add to the index.
vector = np.array(embeddings, dtype=np.float32)
if self._normalize_L2:
faiss.normalize_L2(vector)
self.index.add(vector)

# Add information to docstore and index.
ids = ids or [str(uuid.uuid4()) for _ in texts]
self.docstore.add({id_: doc for id_, doc in zip(ids, documents)})
starting_len = len(self.index_to_docstore_id)
index_to_id = {starting_len + j: id_ for j, id_ in enumerate(ids)}
Expand Down Expand Up @@ -1475,3 +1477,7 @@ def filter_func(filter: Dict[str, Any]) -> Callable[[Dict[str, Any]], bool]:
return lambda doc: all(condition(doc) for condition in conditions)

return filter_func(filter)

def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
docs = [self.docstore.search(id_) for id_ in ids]
return [doc for doc in docs if isinstance(doc, Document)]
Loading

0 comments on commit a3851cb

Please sign in to comment.