From 4a0a1376855b5d6d09dc1d37cf761f34f5bc1431 Mon Sep 17 00:00:00 2001 From: Malcolm Jones Date: Tue, 20 Aug 2024 16:54:02 -0400 Subject: [PATCH] feat: test_add_or_update_documents_new_documents --- src/goob_ai/cli.py | 24 ++++- src/goob_ai/services/chroma_service.py | 119 +++++++++++++++++++------ tests/services/test_chroma_service.py | 116 ++++++++++++++++++++++++ 3 files changed, 233 insertions(+), 26 deletions(-) diff --git a/src/goob_ai/cli.py b/src/goob_ai/cli.py index afcf4dcb..55e529d0 100644 --- a/src/goob_ai/cli.py +++ b/src/goob_ai/cli.py @@ -69,6 +69,7 @@ from goob_ai.services.chroma_service import ChromaService from goob_ai.services.screencrop_service import ImageService from goob_ai.utils import repo_typing +from goob_ai.utils.base import print_line_seperator from goob_ai.utils.file_functions import fix_path @@ -571,7 +572,28 @@ def add_and_query(collection_name: str, question: str, reset: bool = False) -> N # import bpdb # bpdb.set_trace() - ChromaService.get_response(question, collection_name=collection_name) + resp = ChromaService.get_response(question, collection_name=collection_name) + print_line_seperator("test") + rich.print(resp) + except Exception as ex: + print(f"{ex}") + exc_type, exc_value, exc_traceback = sys.exc_info() + print(f"Error Class: {ex.__class__}") + output = f"[UNEXPECTED] {type(ex).__name__}: {ex}" + print(output) + print(f"exc_type: {exc_type}") + print(f"exc_value: {exc_value}") + traceback.print_tb(exc_traceback) + if aiosettings.dev_mode: + bpdb.pm() + + +@APP.command() +def ask(collection_name: str, question: str) -> None: + """Ask vectorstore""" + try: + resp = ChromaService.get_response(question, collection_name=collection_name) + rich.print(resp) except Exception as ex: print(f"{ex}") exc_type, exc_value, exc_traceback = sys.exc_info() diff --git a/src/goob_ai/services/chroma_service.py b/src/goob_ai/services/chroma_service.py index 42fcdff6..c90c40bd 100644 --- a/src/goob_ai/services/chroma_service.py +++ 
b/src/goob_ai/services/chroma_service.py @@ -404,36 +404,97 @@ def calculate_chunk_ids(chunks: list[Document]) -> list[Document]: # SOURCE: https://github.com/divyeg/meakuchatbot_project/blob/0c4483ce4bebce923233cf2a1139f089ac5d9e53/createVectorDB.ipynb#L203 # TODO: Enable and refactor this function -# def add_or_update_documents(db: Chroma, documents: list[Document]) -> None: -# from langchain_community.embeddings import HuggingFaceEmbeddings +@pysnooper.snoop() +def add_or_update_documents( + persist_directory: str = CHROMA_PATH, + disallowed_special: Union[Literal["all"], set[str], Sequence[str], None] = (), + use_custom_openai_embeddings: bool = False, + collection_name: str = "", + path_to_document: str = "", + embedding_function: Any | None = OpenAIEmbeddings(), +) -> None: + """ + Add or update documents in a Chroma database. + + This function loads documents from the specified path, splits them into chunks if necessary, + and adds or updates them in the Chroma database. It uses the appropriate loader and text splitter + based on the file type of the document. + + Args: + persist_directory (str): The directory where the Chroma database is persisted. Defaults to CHROMA_PATH. + disallowed_special (Union[Literal["all"], set[str], Sequence[str], None]): Special characters to disallow in the embeddings. Defaults to an empty tuple. + use_custom_openai_embeddings (bool): Whether to use custom OpenAI embeddings. Defaults to False. + collection_name (str): The name of the collection in the Chroma database. Defaults to an empty string. + path_to_document (str): The path to the document to be added or updated. Defaults to an empty string. + embedding_function (Any | None): The embedding function to use. Defaults to OpenAIEmbeddings(). 
+ + Returns: + None + """ -# embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") + # NOTE: orig code + # from langchain_community.embeddings import HuggingFaceEmbeddings + # embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") -# db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedder) + # Log the input parameters for debugging purposes + LOGGER.debug(f"path_to_document = {path_to_document}") + LOGGER.debug(f"collection_name = {collection_name}") + LOGGER.debug(f"embedding_function = {embedding_function}") -# last_request_time = 0 -# RATE_LIMIT_INTERVAL = 10 + db = get_chroma_db(persist_directory, embedding_function, collection_name=collection_name) + + # Get the Chroma client + client = ChromaService.get_client() + # FIXME: We need to make embedding_function optional + # Add or retrieve the collection with the specified name + collection: chromadb.Collection = ChromaService.add_collection(collection_name) + + # Load the document using the appropriate loader based on the file type + loader: TextLoader | PyMuPDFLoader | WebBaseLoader | None = get_rag_loader(path_to_document) + # Load the documents using the selected loader + documents: list[Document] = loader.load() + + # If the file type is txt, split the documents into chunks + text_splitter = get_rag_splitter(path_to_document) + if text_splitter: + # Split the documents into chunks using the text splitter + chunks: list[Document] = text_splitter.split_documents(documents) + else: + # If no text splitter is available, use the original documents + chunks: list[Document] = documents # type: ignore + + ################################# + + # embedder = OpenAIEmbeddings(openai_api_key=aiosettings.openai_api_key.get_secret_value()) + + # db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedder) + + last_request_time = 0 + RATE_LIMIT_INTERVAL = 10 -# chunks_with_ids = calculate_chunk_ids(chunks) + chunks_with_ids: 
list[Document] = calculate_chunk_ids(chunks) -# # Add or Update the documents. -# existing_items = db.get(include=[]) # IDs are always included by default -# existing_ids = set(existing_items["ids"]) -# print(f"Number of existing documents in DB: {len(existing_ids)}") + # Add or Update the documents. + existing_items = db.get(include=[]) # IDs are always included by default + existing_ids = set(existing_items["ids"]) + LOGGER.info(f"Number of existing documents in DB: {len(existing_ids)}") + + # Only add documents that don't exist in the DB. + new_chunks = [] + for chunk in chunks_with_ids: + if chunk.metadata["id"] not in existing_ids: + new_chunks.append(chunk) + + if len(new_chunks): + LOGGER.info(f"Adding new documents: {len(new_chunks)}") + new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks] + docs_added = db.add_documents(new_chunks, ids=new_chunk_ids) + LOGGER.info(f"Saved {len(docs_added)} chunks to {CHROMA_PATH}.") + # db.persist() + else: + LOGGER.info("No new documents to add") -# # Only add documents that don't exist in the DB. -# new_chunks = [] -# for chunk in chunks_with_ids: -# if chunk.metadata["id"] not in existing_ids: -# new_chunks.append(chunk) -# if len(new_chunks): -# print(f"Adding new documents: {len(new_chunks)}") -# new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks] -# db.add_documents(new_chunks, ids=new_chunk_ids) -# db.persist() -# else: -# print("No new documents to add") # Number of existing documents in DB: 0 # Adding new documents: 10 # Saved 10 chunks to input_data/chroma. @@ -629,7 +690,9 @@ def get_rag_embedding_function( return None -def get_client(host: str = aiosettings.chroma_host, port: int = aiosettings.chroma_port) -> chromadb.ClientAPI: +def get_client( + host: str = aiosettings.chroma_host, port: int = aiosettings.chroma_port, **kwargs: Any +) -> chromadb.ClientAPI: """Get the ChromaDB client. 
Returns: @@ -638,7 +701,7 @@ def get_client(host: str = aiosettings.chroma_host, port: int = aiosettings.chro return chromadb.HttpClient( host=host, port=port, - settings=ChromaSettings(allow_reset=True, is_persistent=True), + settings=ChromaSettings(allow_reset=True, is_persistent=True, persist_directory=CHROMA_PATH, **kwargs), ) @@ -1007,8 +1070,14 @@ def save_to_chroma( ) db = vectorstore + # vectorstore.persist() + LOGGER.info(f"Saved {len(chunks)} chunks to {CHROMA_PATH}.") + # import bpdb + + # bpdb.set_trace() + retriever: VectorStoreRetriever = vectorstore.as_retriever() except Exception as ex: diff --git a/tests/services/test_chroma_service.py b/tests/services/test_chroma_service.py index f4d561ed..c053b502 100644 --- a/tests/services/test_chroma_service.py +++ b/tests/services/test_chroma_service.py @@ -4,6 +4,7 @@ import json import logging import os +import random import shutil import tempfile @@ -14,9 +15,11 @@ from chromadb import Collection from goob_ai.aio_settings import aiosettings from goob_ai.services.chroma_service import ( + CHROMA_PATH, CHROMA_PATH_API, ChromaService, CustomOpenAIEmbeddings, + add_or_update_documents, calculate_chunk_ids, compare_two_words, create_chroma_db, @@ -85,6 +88,15 @@ # assert "Comparing (apple, banana):" in caplog.text # caplog.clear() +FIRST_NAMES = ["Ada", "Bela", "Cade", "Dax", "Eva", "Fynn", "Gia", "Hugo", "Ivy", "Jax"] +LAST_NAMES = ["Smith", "Johnson", "Williams", "Jones", "Brown", "Davis", "Miller", "Wilson", "Moore", "Taylor"] + + +def generate_random_name(): + first_name = random.choice(FIRST_NAMES) + last_name = random.choice(LAST_NAMES) + return f"{first_name}_{last_name}".lower() + def test_calculate_chunk_ids(): """ @@ -1377,3 +1389,107 @@ async def test_generate_document_hashes(): # # Assert that the result is an instance of VectorStoreRetriever # assert isinstance(result, VectorStoreRetriever) + + +@pytest.fixture() +def mock_chroma_db(mocker: MockerFixture): + """Fixture to create a mock Chroma 
database.""" + mock_db = mocker.MagicMock() + mock_db.get.return_value = {"ids": []} + return mock_db + + +@pytest.fixture() +def mock_loader(mocker: MockerFixture): + """Fixture to create a mock loader.""" + mock_loader = mocker.MagicMock() + mock_loader.load.return_value = [Document(page_content="Test document")] + return mock_loader + + +@pytest.fixture() +def mock_text_splitter(mocker: MockerFixture): + """Fixture to create a mock text splitter.""" + mock_splitter = mocker.MagicMock() + mock_splitter.split_documents.return_value = [Document(page_content="Test chunk")] + return mock_splitter + + +@pytest.mark.unittest() +def test_add_or_update_documents_new_documents( + mocker: MockerFixture, + mock_chroma_db: MockerFixture, + mock_loader: MockerFixture, + mock_text_splitter: MockerFixture, + mock_pdf_file: Path, + caplog: LogCaptureFixture, + capsys: CaptureFixture, +): + """ + Test adding new documents to the Chroma database. + + This test verifies that the `add_or_update_documents` function correctly adds + new documents to the Chroma database when they don't exist. 
+ """ + caplog.set_level(logging.DEBUG) + + mock_get_rag_splitter = mocker.patch("goob_ai.services.chroma_service.get_rag_splitter") + mock_get_chroma_db = mocker.patch("goob_ai.services.chroma_service.get_chroma_db") + mock_get_rag_loader = mocker.patch("goob_ai.services.chroma_service.get_rag_loader") + + mock_get_chroma_db.return_value = mock_chroma_db + mock_get_rag_loader.return_value = mock_loader + mock_get_rag_splitter.return_value = mock_text_splitter + + collection_name = generate_random_name() + + add_or_update_documents(path_to_document=f"{mock_pdf_file}", collection_name=collection_name) + + mock_chroma_db.add_documents.assert_called_once() + mock_chroma_db.get.assert_called_once_with(include=[]) + + +@pytest.mark.unittest() +def test_add_or_update_documents_existing_documents( + mocker: MockerFixture, + mock_chroma_db: MockerFixture, + mock_loader: MockerFixture, + mock_text_splitter: MockerFixture, + mock_pdf_file: Path, + caplog: LogCaptureFixture, + capsys: CaptureFixture, +): + """ + Test adding documents when the Chroma database already holds other IDs. + + The seeded ID (`test_chunk_id`) differs from the generated chunk ID (`None:None:0`), so this + verifies the chunk is still added exactly once; the skip-existing path is not exercised here.
+ """ + caplog.set_level(logging.DEBUG) + + mock_get_rag_splitter = mocker.patch("goob_ai.services.chroma_service.get_rag_splitter") + mock_get_chroma_db = mocker.patch("goob_ai.services.chroma_service.get_chroma_db") + mock_get_rag_loader = mocker.patch("goob_ai.services.chroma_service.get_rag_loader") + + mock_get_chroma_db.return_value = mock_chroma_db + mock_get_rag_loader.return_value = mock_loader + mock_get_rag_splitter.return_value = mock_text_splitter + mock_chroma_db.get.return_value = {"ids": ["test_chunk_id"]} + + collection_name = generate_random_name() + + add_or_update_documents(path_to_document=f"{mock_pdf_file}", collection_name=collection_name) + + assert mock_chroma_db.add_documents.call_count == 1 + assert mock_chroma_db.add_documents.call_args.kwargs == {"ids": ["None:None:0"]} + assert mock_chroma_db.add_documents.call_args.args == ( + [Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], + ) + + calls = [ + mocker.call([Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], ids=["None:None:0"]), + # mocker.call([Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], ids=["None:None:0"]), + # mocker.call([Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], ids=["None:None:0"]), + # mocker.call([Document(metadata={"id": "None:None:0"}, page_content="Test chunk")], ids=["None:None:0"]), + ] + assert mock_chroma_db.add_documents.call_args_list == calls