Skip to content

Commit

Permalink
feat: test_add_or_update_documents_new_documents
Browse files Browse the repository at this point in the history
  • Loading branch information
bossjones committed Aug 20, 2024
1 parent 16d9ad6 commit 4a0a137
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 26 deletions.
24 changes: 23 additions & 1 deletion src/goob_ai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from goob_ai.services.chroma_service import ChromaService
from goob_ai.services.screencrop_service import ImageService
from goob_ai.utils import repo_typing
from goob_ai.utils.base import print_line_seperator
from goob_ai.utils.file_functions import fix_path


Expand Down Expand Up @@ -571,7 +572,28 @@ def add_and_query(collection_name: str, question: str, reset: bool = False) -> N
# import bpdb

# bpdb.set_trace()
ChromaService.get_response(question, collection_name=collection_name)
resp = ChromaService.get_response(question, collection_name=collection_name)
print_line_seperator("test")
rich.print(resp)
except Exception as ex:
print(f"{ex}")
exc_type, exc_value, exc_traceback = sys.exc_info()
print(f"Error Class: {ex.__class__}")
output = f"[UNEXPECTED] {type(ex).__name__}: {ex}"
print(output)
print(f"exc_type: {exc_type}")
print(f"exc_value: {exc_value}")
traceback.print_tb(exc_traceback)

Check warning on line 586 in src/goob_ai/cli.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/cli.py#L575-L586

Added lines #L575 - L586 were not covered by tests
if aiosettings.dev_mode:
bpdb.pm()

Check warning on line 588 in src/goob_ai/cli.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/cli.py#L588

Added line #L588 was not covered by tests


@APP.command()
def ask(collection_name: str, question: str) -> None:
"""Ask vectorstore"""
try:
resp = ChromaService.get_response(question, collection_name=collection_name)
rich.print(resp)
except Exception as ex:
print(f"{ex}")
exc_type, exc_value, exc_traceback = sys.exc_info()
Expand Down
119 changes: 94 additions & 25 deletions src/goob_ai/services/chroma_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,36 +404,97 @@ def calculate_chunk_ids(chunks: list[Document]) -> list[Document]:

# SOURCE: https://github.com/divyeg/meakuchatbot_project/blob/0c4483ce4bebce923233cf2a1139f089ac5d9e53/createVectorDB.ipynb#L203
# TODO: Enable and refactor this function
# def add_or_update_documents(db: Chroma, documents: list[Document]) -> None:
# from langchain_community.embeddings import HuggingFaceEmbeddings
@pysnooper.snoop()
def add_or_update_documents(
persist_directory: str = CHROMA_PATH,
disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = (),
use_custom_openai_embeddings: bool = False,
collection_name: str = "",
path_to_document: str = "",
embedding_function: Any | None = OpenAIEmbeddings(),
) -> None:
"""
Add or update documents in a Chroma database.
This function loads documents from the specified path, splits them into chunks if necessary,
and adds or updates them in the Chroma database. It uses the appropriate loader and text splitter
based on the file type of the document.
Args:
persist_directory (str): The directory where the Chroma database is persisted. Defaults to CHROMA_PATH.
disallowed_special (Union[Literal["all"], set[str], Sequence[str], None]): Special characters to disallow in the embeddings. Defaults to an empty tuple.
use_custom_openai_embeddings (bool): Whether to use custom OpenAI embeddings. Defaults to False.
collection_name (str): The name of the collection in the Chroma database. Defaults to an empty string.
path_to_document (str): The path to the document to be added or updated. Defaults to an empty string.
embedding_function (Any | None): The embedding function to use. Defaults to OpenAIEmbeddings().
Returns:
None
"""

# embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# NOTE: orig code
# from langchain_community.embeddings import HuggingFaceEmbeddings
# embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedder)
# Log the input parameters for debugging purposes
LOGGER.debug(f"path_to_document = {path_to_document}")
LOGGER.debug(f"collection_name = {collection_name}")
LOGGER.debug(f"embedding_function = {embedding_function}")

Check warning on line 442 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L440-L442

Added lines #L440 - L442 were not covered by tests

# last_request_time = 0
# RATE_LIMIT_INTERVAL = 10
db = get_chroma_db(persist_directory, embedding_function, collection_name=collection_name)

Check warning on line 444 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L444

Added line #L444 was not covered by tests

# Get the Chroma client
client = ChromaService.get_client()

Check warning on line 447 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L447

Added line #L447 was not covered by tests
# FIXME: We need to make embedding_function optional
# Add or retrieve the collection with the specified name
collection: chromadb.Collection = ChromaService.add_collection(collection_name)

Check warning on line 450 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L450

Added line #L450 was not covered by tests

# Load the document using the appropriate loader based on the file type
loader: TextLoader | PyMuPDFLoader | WebBaseLoader | None = get_rag_loader(path_to_document)

Check warning on line 453 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L453

Added line #L453 was not covered by tests
# Load the documents using the selected loader
documents: list[Document] = loader.load()

Check warning on line 455 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L455

Added line #L455 was not covered by tests

# If the file type is txt, split the documents into chunks
text_splitter = get_rag_splitter(path_to_document)

Check warning on line 458 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L458

Added line #L458 was not covered by tests
if text_splitter:
# Split the documents into chunks using the text splitter
chunks: list[Document] = text_splitter.split_documents(documents)

Check warning on line 461 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L461

Added line #L461 was not covered by tests
else:
# If no text splitter is available, use the original documents
chunks: list[Document] = documents # type: ignore

Check warning on line 464 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L464

Added line #L464 was not covered by tests

#################################

# embedder = OpenAIEmbeddings(openai_api_key=aiosettings.openai_api_key.get_secret_value())

# db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedder)

last_request_time = 0
RATE_LIMIT_INTERVAL = 10

Check warning on line 473 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L472-L473

Added lines #L472 - L473 were not covered by tests

# chunks_with_ids = calculate_chunk_ids(chunks)
chunks_with_ids: list[Document] = calculate_chunk_ids(chunks)

Check warning on line 475 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L475

Added line #L475 was not covered by tests

# # Add or Update the documents.
# existing_items = db.get(include=[]) # IDs are always included by default
# existing_ids = set(existing_items["ids"])
# print(f"Number of existing documents in DB: {len(existing_ids)}")
# Add or Update the documents.
existing_items = db.get(include=[]) # IDs are always included by default
existing_ids = set(existing_items["ids"])
LOGGER.info(f"Number of existing documents in DB: {len(existing_ids)}")

Check warning on line 480 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L478-L480

Added lines #L478 - L480 were not covered by tests

# Only add documents that don't exist in the DB.
new_chunks = []

Check warning on line 483 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L483

Added line #L483 was not covered by tests
for chunk in chunks_with_ids:
if chunk.metadata["id"] not in existing_ids:
new_chunks.append(chunk)

Check warning on line 486 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L486

Added line #L486 was not covered by tests

if len(new_chunks):
LOGGER.info(f"Adding new documents: {len(new_chunks)}")

Check warning on line 489 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L489

Added line #L489 was not covered by tests
new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks]
docs_added = db.add_documents(new_chunks, ids=new_chunk_ids)
LOGGER.info(f"Saved {len(docs_added)} chunks to {CHROMA_PATH}.")

Check warning on line 492 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L491-L492

Added lines #L491 - L492 were not covered by tests
# db.persist()
else:
LOGGER.info("No new documents to add")

Check warning on line 495 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L495

Added line #L495 was not covered by tests

# # Only add documents that don't exist in the DB.
# new_chunks = []
# for chunk in chunks_with_ids:
# if chunk.metadata["id"] not in existing_ids:
# new_chunks.append(chunk)

# if len(new_chunks):
# print(f"Adding new documents: {len(new_chunks)}")
# new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks]
# db.add_documents(new_chunks, ids=new_chunk_ids)
# db.persist()
# else:
# print("No new documents to add")
# Number of existing documents in DB: 0
# Adding new documents: 10
# Saved 10 chunks to input_data/chroma.
Expand Down Expand Up @@ -629,7 +690,9 @@ def get_rag_embedding_function(
return None


def get_client(host: str = aiosettings.chroma_host, port: int = aiosettings.chroma_port) -> chromadb.ClientAPI:
def get_client(
host: str = aiosettings.chroma_host, port: int = aiosettings.chroma_port, **kwargs: Any
) -> chromadb.ClientAPI:
"""Get the ChromaDB client.
Returns:
Expand All @@ -638,7 +701,7 @@ def get_client(host: str = aiosettings.chroma_host, port: int = aiosettings.chro
return chromadb.HttpClient(
host=host,
port=port,
settings=ChromaSettings(allow_reset=True, is_persistent=True),
settings=ChromaSettings(allow_reset=True, is_persistent=True, persist_directory=CHROMA_PATH, **kwargs),
)


Expand Down Expand Up @@ -1007,8 +1070,14 @@ def save_to_chroma(
)
db = vectorstore

Check warning on line 1071 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L1071

Added line #L1071 was not covered by tests

# vectorstore.persist()

LOGGER.info(f"Saved {len(chunks)} chunks to {CHROMA_PATH}.")

Check warning on line 1075 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L1075

Added line #L1075 was not covered by tests

# import bpdb

# bpdb.set_trace()

retriever: VectorStoreRetriever = vectorstore.as_retriever()

Check warning on line 1081 in src/goob_ai/services/chroma_service.py

View check run for this annotation

Codecov / codecov/patch

src/goob_ai/services/chroma_service.py#L1081

Added line #L1081 was not covered by tests

except Exception as ex:
Expand Down
116 changes: 116 additions & 0 deletions tests/services/test_chroma_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import os
import random
import shutil
import tempfile

Expand All @@ -14,9 +15,11 @@
from chromadb import Collection
from goob_ai.aio_settings import aiosettings
from goob_ai.services.chroma_service import (
CHROMA_PATH,
CHROMA_PATH_API,
ChromaService,
CustomOpenAIEmbeddings,
add_or_update_documents,
calculate_chunk_ids,
compare_two_words,
create_chroma_db,
Expand Down Expand Up @@ -85,6 +88,15 @@
# assert "Comparing (apple, banana):" in caplog.text
# caplog.clear()

FIRST_NAMES = ["Ada", "Bela", "Cade", "Dax", "Eva", "Fynn", "Gia", "Hugo", "Ivy", "Jax"]
LAST_NAMES = ["Smith", "Johnson", "Williams", "Jones", "Brown", "Davis", "Miller", "Wilson", "Moore", "Taylor"]


def generate_random_name():
first_name = random.choice(FIRST_NAMES)
last_name = random.choice(LAST_NAMES)
return f"{first_name}_{last_name}".lower()


def test_calculate_chunk_ids():
"""
Expand Down Expand Up @@ -1377,3 +1389,107 @@ async def test_generate_document_hashes():

# # Assert that the result is an instance of VectorStoreRetriever
# assert isinstance(result, VectorStoreRetriever)


@pytest.fixture()
def mock_chroma_db(mocker: MockerFixture):
"""Fixture to create a mock Chroma database."""
mock_db = mocker.MagicMock()
mock_db.get.return_value = {"ids": []}
return mock_db


@pytest.fixture()
def mock_loader(mocker: MockerFixture):
"""Fixture to create a mock loader."""
mock_loader = mocker.MagicMock()
mock_loader.load.return_value = [Document(page_content="Test document")]
return mock_loader


@pytest.fixture()
def mock_text_splitter(mocker: MockerFixture):
"""Fixture to create a mock text splitter."""
mock_splitter = mocker.MagicMock()
mock_splitter.split_documents.return_value = [Document(page_content="Test chunk")]
return mock_splitter


@pytest.mark.unittest()
def test_add_or_update_documents_new_documents(
mocker: MockerFixture,
mock_chroma_db: MockerFixture,
mock_loader: MockerFixture,
mock_text_splitter: MockerFixture,
mock_pdf_file: Path,
caplog: LogCaptureFixture,
capsys: CaptureFixture,
):
"""
Test adding new documents to the Chroma database.
This test verifies that the `add_or_update_documents` function correctly adds
new documents to the Chroma database when they don't exist.
"""
caplog.set_level(logging.DEBUG)

mock_get_rag_splitter = mocker.patch("goob_ai.services.chroma_service.get_rag_splitter")
mock_get_chroma_db = mocker.patch("goob_ai.services.chroma_service.get_chroma_db")
mock_get_rag_loader = mocker.patch("goob_ai.services.chroma_service.get_rag_loader")

mock_get_chroma_db.return_value = mock_chroma_db
mock_get_rag_loader.return_value = mock_loader
mock_get_rag_splitter.return_value = mock_text_splitter

collection_name = generate_random_name()

add_or_update_documents(path_to_document=f"{mock_pdf_file}", collection_name=collection_name)

mock_chroma_db.add_documents.assert_called_once()
mock_chroma_db.get.assert_called_once_with(include=[])


@pytest.mark.unittest()
def test_add_or_update_documents_existing_documents(
mocker: MockerFixture,
mock_chroma_db: MockerFixture,
mock_loader: MockerFixture,
mock_text_splitter: MockerFixture,
mock_pdf_file: Path,
caplog: LogCaptureFixture,
capsys: CaptureFixture,
):
"""
Test adding existing documents to the Chroma database.
This test verifies that the `add_or_update_documents` function correctly skips
adding documents that already exist in the Chroma database.
"""
caplog.set_level(logging.DEBUG)

mock_get_rag_splitter = mocker.patch("goob_ai.services.chroma_service.get_rag_splitter")
mock_get_chroma_db = mocker.patch("goob_ai.services.chroma_service.get_chroma_db")
mock_get_rag_loader = mocker.patch("goob_ai.services.chroma_service.get_rag_loader")

mock_get_chroma_db.return_value = mock_chroma_db
mock_get_rag_loader.return_value = mock_loader
mock_get_rag_splitter.return_value = mock_text_splitter
mock_chroma_db.get.return_value = {"ids": ["test_chunk_id"]}

collection_name = generate_random_name()

add_or_update_documents(path_to_document=f"{mock_pdf_file}", collection_name=collection_name)

assert mock_chroma_db.add_documents.call_count == 1
assert mock_chroma_db.add_documents.call_args.kwargs == {"ids": ["None:None:0"]}
assert mock_chroma_db.add_documents.call_args.args == (
[Document(metadata={"id": "None:None:0"}, page_content="Test chunk")],
)

calls = [
mocker.call([Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], ids=["None:None:0"]),
# mocker.call([Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], ids=["None:None:0"]),
# mocker.call([Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], ids=["None:None:0"]),
# mocker.call([Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], ids=["None:None:0"]),
]
assert mock_chroma_db.add_documents.call_args_list == calls

0 comments on commit 4a0a137

Please sign in to comment.