Skip to content

Commit f5a9bf4

Browse files
authored
fix: chromadb max batch size (#1087)
1 parent b46c108 commit f5a9bf4

File tree

5 files changed

+142
-68
lines changed

5 files changed

+142
-68
lines changed

poetry.lock

+20-62
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from typing import Any
2+
3+
from llama_index.schema import BaseNode, MetadataMode
4+
from llama_index.vector_stores import ChromaVectorStore
5+
from llama_index.vector_stores.chroma import chunk_list
6+
from llama_index.vector_stores.utils import node_to_metadata_dict
7+
8+
9+
class BatchedChromaVectorStore(ChromaVectorStore):
    """Chroma vector store that splits insertions into client-sized batches.

    Embeddings live in a ChromaDB collection; at query time the index asks
    ChromaDB for the top-k most similar nodes. Additions are chunked so that
    no single call to the collection exceeds the client's reported
    ``max_batch_size``.

    Args:
        chroma_client (from chromadb.api.API):
            API instance
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance

    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode]) -> list[str]:
        """Add nodes to the index, inserting in batches to avoid issues.

        Args:
            nodes: list of nodes with embeddings

        Returns:
            The ids of every inserted node, in insertion order.

        Raises:
            ValueError: if the client or the collection is not initialized.

        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")

        if not self._collection:
            raise ValueError("Collection not initialized")

        all_ids: list[str] = []
        # chunk_list yields slices no larger than the client's batch limit.
        for batch in chunk_list(nodes, self.chroma_client.max_batch_size):
            batch_ids = [node.node_id for node in batch]
            self._collection.add(
                embeddings=[node.get_embedding() for node in batch],
                ids=batch_ids,
                metadatas=[
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                    for node in batch
                ],
                documents=[
                    node.get_content(metadata_mode=MetadataMode.NONE)
                    for node in batch
                ],
            )
            all_ids.extend(batch_ids)

        return all_ids

private_gpt/components/vector_store/vector_store_component.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from injector import inject, singleton
55
from llama_index import VectorStoreIndex
66
from llama_index.indices.vector_store import VectorIndexRetriever
7-
from llama_index.vector_stores import ChromaVectorStore
87
from llama_index.vector_stores.types import VectorStore
98

9+
from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
1010
from private_gpt.open_ai.extensions.context_filter import ContextFilter
1111
from private_gpt.paths import local_data_path
1212

@@ -36,14 +36,16 @@ class VectorStoreComponent:
3636

3737
@inject
3838
def __init__(self) -> None:
39-
db = chromadb.PersistentClient(
39+
chroma_client = chromadb.PersistentClient(
4040
path=str((local_data_path / "chroma_db").absolute())
4141
)
42-
chroma_collection = db.get_or_create_collection(
42+
chroma_collection = chroma_client.get_or_create_collection(
4343
"make_this_parameterizable_per_api_call"
4444
) # TODO
4545

46-
self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
46+
self.vector_store = BatchedChromaVectorStore(
47+
chroma_client=chroma_client, chroma_collection=chroma_collection
48+
)
4749

4850
@staticmethod
4951
def get_retriever(

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ description = "Private GPT"
55
authors = ["Zylon <hi@zylon.ai>"]
66

77
[tool.poetry.dependencies]
8-
python = ">=3.11,<3.13"
8+
python = ">=3.11,<3.12"
99
fastapi = { extras = ["all"], version = "^0.103.1" }
1010
loguru = "^0.7.2"
1111
boto3 = "^1.28.56"
1212
injector = "^0.21.0"
1313
pyyaml = "^6.0.1"
1414
python-multipart = "^0.0.6"
1515
pypdf = "^3.16.2"
16-
llama-index = "v0.8.35"
16+
llama-index = "0.8.47"
1717
chromadb = "^0.4.13"
1818
watchdog = "^3.0.0"
1919
transformers = "^4.34.0"
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from unittest.mock import PropertyMock, patch
2+
3+
from llama_index import Document
4+
5+
from private_gpt.server.ingest.ingest_service import IngestService
6+
from tests.fixtures.mock_injector import MockInjector
7+
8+
9+
def test_save_many_nodes(injector: MockInjector) -> None:
    """This is a specific test for a local Chromadb Vector Database setup.

    Extend it when we add support for other vector databases in VectorStoreComponent.
    """
    with patch(
        "chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock
    ) as max_batch_size:
        # Shrink Chromadb's max batch size so the batching path is exercised
        # with far fewer documents than a real deployment would need.
        max_batch_size.return_value = 10

        ingest_service = injector.get(IngestService)

        documents = [Document(text="This is a sentence.") for _ in range(100)]

        ingested_docs = ingest_service._save_docs(documents)
        assert len(ingested_docs) == len(documents)

0 commit comments

Comments
 (0)