Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: chromadb max batch size #1084 #1087

Merged
merged 1 commit into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 20 additions & 62 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

87 changes: 87 additions & 0 deletions private_gpt/components/vector_store/batched_chroma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Any

from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.chroma import chunk_list
from llama_index.vector_stores.utils import node_to_metadata_dict


class BatchedChromaVectorStore(ChromaVectorStore):
    """Chroma vector store that splits additions into size-limited batches.

    Embeddings are persisted in a ChromaDB collection; query-time retrieval
    of the top-k most similar nodes is delegated to ChromaDB. Unlike the
    plain ``ChromaVectorStore``, insertions are chunked so that a single
    call never exceeds the client's ``max_batch_size`` limit.

    Args:
        chroma_client (from chromadb.api.API):
            API instance
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance

    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        # Kept separately from the collection so `add` can read the
        # client-level max_batch_size limit.
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode]) -> list[str]:
        """Insert ``nodes`` into the collection, chunked to respect batch limits.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            The node ids of every inserted node, in insertion order.

        Raises:
            ValueError: if the client or the collection was not initialized.

        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")

        if not self._collection:
            raise ValueError("Collection not initialized")

        # chunk_list yields slices no larger than the client's hard limit.
        batches = chunk_list(nodes, self.chroma_client.max_batch_size)

        inserted_ids: list[str] = []
        for batch in batches:
            batch_ids = [node.node_id for node in batch]
            self._collection.add(
                embeddings=[node.get_embedding() for node in batch],
                ids=batch_ids,
                metadatas=[
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                    for node in batch
                ],
                documents=[
                    node.get_content(metadata_mode=MetadataMode.NONE)
                    for node in batch
                ],
            )
            inserted_ids.extend(batch_ids)

        return inserted_ids
10 changes: 6 additions & 4 deletions private_gpt/components/vector_store/vector_store_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from injector import inject, singleton
from llama_index import VectorStoreIndex
from llama_index.indices.vector_store import VectorIndexRetriever
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.types import VectorStore

from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.paths import local_data_path

Expand Down Expand Up @@ -36,14 +36,16 @@ class VectorStoreComponent:

@inject
def __init__(self) -> None:
db = chromadb.PersistentClient(
chroma_client = chromadb.PersistentClient(
path=str((local_data_path / "chroma_db").absolute())
)
chroma_collection = db.get_or_create_collection(
chroma_collection = chroma_client.get_or_create_collection(
"make_this_parameterizable_per_api_call"
) # TODO

self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
self.vector_store = BatchedChromaVectorStore(
chroma_client=chroma_client, chroma_collection=chroma_collection
)

@staticmethod
def get_retriever(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ description = "Private GPT"
authors = ["Zylon <hi@zylon.ai>"]

[tool.poetry.dependencies]
python = ">=3.11,<3.13"
python = ">=3.11,<3.12"
fastapi = { extras = ["all"], version = "^0.103.1" }
loguru = "^0.7.2"
boto3 = "^1.28.56"
injector = "^0.21.0"
pyyaml = "^6.0.1"
python-multipart = "^0.0.6"
pypdf = "^3.16.2"
llama-index = "v0.8.35"
llama-index = "0.8.47"
chromadb = "^0.4.13"
watchdog = "^3.0.0"
transformers = "^4.34.0"
Expand Down
27 changes: 27 additions & 0 deletions tests/server/ingest/test_ingest_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from unittest.mock import PropertyMock, patch

from llama_index import Document

from private_gpt.server.ingest.ingest_service import IngestService
from tests.fixtures.mock_injector import MockInjector


def test_save_many_nodes(injector: MockInjector) -> None:
    """This is a specific test for a local Chromadb Vector Database setup.

    Forces a tiny Chroma ``max_batch_size`` so that ingesting 100 documents
    exercises the batching logic in ``BatchedChromaVectorStore.add``.

    Extend it when we add support for other vector databases in VectorStoreComponent.
    """
    with patch(
        "chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock
    ) as max_batch_size:
        # Make max batch size of Chromadb very small
        max_batch_size.return_value = 10

        ingest_service = injector.get(IngestService)

        # 100 documents >> batch size of 10, so multiple batches are required.
        documents = [Document(text="This is a sentence.") for _ in range(100)]

        # Every document must survive the batched insertion.
        ingested_docs = ingest_service._save_docs(documents)
        assert len(ingested_docs) == len(documents)