# NOTE(review): this region is a patch touching two files; corrected content below.
#
# 1) autogen/agentchat/contrib/retrieve_user_proxy_agent.py — the patched
#    __init__ gains `vector_database: Optional[str] = "chromadb"` but then
#    unconditionally overwrites any value the caller already put in
#    `retrieve_config`.  The corrected line inside __init__ is:
#
#        self._retrieve_config.setdefault("vector_database", vector_database)
#
# 2) autogen/retrieve_utils.py — corrected replacement code follows.  Fixes:
#    * syntax error in `create_vector_db_from_dir`
#      (`max_tokens = max_tokens: int = 4000`) and the loss of all public
#      defaults/annotations (backward-incompatible as patched);
#    * duplicated `split_files_to_chunks(...)` call in the chromadb path;
#    * lancedb helpers referenced an undefined `LanceDB` object, nonexistent
#      `db.get_table`/`table.insert` APIs, two different undefined embedding
#      wrappers, and the wrong parameter name (`embedding_model` vs
#      `embedding_model_name`);
#    * dispatchers hard-coded the "all_documents" table name and
#      `query_lancedb` ignored `search_string`.

from typing import Dict, List
import logging

logger = logging.getLogger(__name__)


def create_chromadb_from_dir(
    dir_path: str,
    max_tokens: int = 4000,
    client=None,
    db_path: str = "/tmp/chromadb.db",
    collection_name: str = "all-my-documents",
    get_or_create: bool = False,
    chunk_mode: str = "multi_lines",
    must_break_at_empty_line: bool = True,
    embedding_model: str = "all-MiniLM-L6-v2",
):
    """Create a ChromaDB collection from all the files in a given directory.

    Args:
        dir_path: directory whose files are chunked and embedded.
        max_tokens: maximum tokens per chunk.
        client: an existing chromadb client; a PersistentClient at `db_path`
            is created when None.
        db_path: on-disk location for the persistent client.
        collection_name: name of the collection to create.
        get_or_create: reuse an existing collection instead of failing.
        chunk_mode / must_break_at_empty_line: forwarded to
            `split_files_to_chunks`.
        embedding_model: SentenceTransformer model name.
    """
    import chromadb
    import chromadb.utils.embedding_functions as ef

    if client is None:
        client = chromadb.PersistentClient(path=db_path)
    try:
        embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
        collection = client.create_collection(
            collection_name,
            get_or_create=get_or_create,
            embedding_function=embedding_function,
            # ip, l2, cosine
            metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32},
        )
        # Chunk the directory exactly once (the patch accidentally duplicated
        # this call, chunking every file twice).
        chunks = split_files_to_chunks(
            get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line
        )
        logger.info(f"Found {len(chunks)} chunks.")
        # Upsert in batches of at most 40000 chunks.
        # NOTE(review): batch body reconstructed — the original upsert lines are
        # context outside the visible hunk; confirm against the pre-patch file.
        for start in range(0, len(chunks), 40000):
            end = min(start + 40000, len(chunks))
            collection.upsert(
                documents=chunks[start:end],
                ids=[f"doc_{j}" for j in range(start, end)],
            )
    except ValueError as e:
        logger.warning(f"{e}")


def query_chromadb(
    query_texts: List[str],
    n_results: int = 10,
    client=None,
    db_path: str = "/tmp/chromadb.db",
    collection_name: str = "all-my-documents",
    search_string: str = "",
    embedding_model: str = "all-MiniLM-L6-v2",
) -> Dict[str, List[str]]:
    """Query a ChromaDB collection; returns the raw chromadb result dict."""
    import chromadb
    import chromadb.utils.embedding_functions as ef

    if client is None:
        client = chromadb.PersistentClient(path=db_path)
    # The collection's stored embedding function is always the default one, so
    # re-create the function used at build time and embed the queries here.
    embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model)
    collection = client.get_collection(collection_name)
    query_embeddings = embedding_function(query_texts)
    results = collection.query(
        query_embeddings=query_embeddings,
        n_results=n_results,
        where_document={"$contains": search_string} if search_string else None,  # optional filter
    )
    return results


def create_lancedb_from_dir(
    dir_path: str,
    max_tokens: int = 4000,
    db_path: str = "/tmp/lancedb",
    table_name: str = "all-my-documents",
    chunk_mode: str = "multi_lines",
    must_break_at_empty_line: bool = True,
    embedding_model_name: str = "all-MiniLM-L6-v2",
):
    """Create a LanceDB table from all the files in a given directory."""
    import lancedb  # the patch referenced an undefined `LanceDB` name
    import chromadb.utils.embedding_functions as ef

    db = lancedb.connect(db_path)
    try:
        # Same embedding wrapper as the chromadb path, with the *correct*
        # parameter (`embedding_model_name`; the patch used the undefined
        # `embedding_model`).
        embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model_name)
        chunks = split_files_to_chunks(
            get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line
        )
        logger.info(f"Found {len(chunks)} chunks.")  # patch used print(); keep logging consistent
        # The embedding function is callable; it has no `generate_embeddings`.
        embeddings = embedding_function(chunks)
        data = [
            {"vector": vector, "text": chunk, "id": f"doc_{i}"}
            for i, (chunk, vector) in enumerate(zip(chunks, embeddings))
        ]
        # lancedb has no `get_table`/`insert`; create (or replace) the table in
        # one bulk write instead of one insert per chunk.
        db.create_table(table_name, data=data, mode="overwrite")
    except ValueError as e:
        logger.warning(f"{e}")


def query_lancedb(
    query_texts: List[str],
    n_results: int = 10,
    db_path: str = "/tmp/lancedb",
    table_name: str = "all-my-documents",
    search_string: str = "",
    embedding_model_name: str = "all-MiniLM-L6-v2",
) -> List[List[Dict]]:
    """Query a LanceDB table; returns one result list per query text."""
    import lancedb
    import chromadb.utils.embedding_functions as ef

    db = lancedb.connect(db_path)
    table = db.open_table(table_name)  # `get_table` does not exist in lancedb
    embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model_name)
    query_embeddings = embedding_function(query_texts)
    results = []
    for vector in query_embeddings:
        query = table.search(vector).limit(n_results)
        if search_string:
            # NOTE(review): the patch silently dropped `search_string`; apply it
            # as a substring filter so both backends honour the same argument.
            # The value is interpolated into a LanceDB SQL filter — sanitize or
            # parameterize before exposing to untrusted input.
            escaped = search_string.replace("'", "''")
            query = query.where(f"text LIKE '%{escaped}%'")
        results.append(query.to_list())
    return results


def create_vector_db_from_dir(
    dir_path: str,
    max_tokens: int = 4000,
    client=None,
    db_path: str = "/tmp/chromadb.db",
    collection_name: str = "all-my-documents",
    get_or_create: bool = False,
    chunk_mode: str = "multi_lines",
    must_break_at_empty_line: bool = True,
    embedding_model: str = "all-MiniLM-L6-v2",
    vector_database: str = "chromadb",
):
    """Create a vector db from all files in `dir_path`.

    Dispatches on `vector_database` ("chromadb" or "lancedb").  Defaults and
    annotations match the pre-patch public signature, so existing callers are
    unaffected.

    Raises:
        ValueError: if `vector_database` is not a supported backend.
    """
    if vector_database == "chromadb":
        create_chromadb_from_dir(
            dir_path, max_tokens, client, db_path, collection_name,
            get_or_create, chunk_mode, must_break_at_empty_line, embedding_model,
        )
    elif vector_database == "lancedb":
        # Honour the caller's collection_name (the patch hard-coded
        # "all_documents", so create and query could disagree on the table).
        create_lancedb_from_dir(
            dir_path, max_tokens, db_path, collection_name,
            chunk_mode, must_break_at_empty_line, embedding_model,
        )
    else:
        raise ValueError("Invalid vector_database. Please choose 'chromadb' or 'lancedb'.")


def query_vector_db(
    query_texts: List[str],
    n_results: int = 10,
    client=None,
    db_path: str = "/tmp/chromadb.db",
    collection_name: str = "all-my-documents",
    search_string: str = "",
    embedding_model: str = "all-MiniLM-L6-v2",
    vector_database: str = "chromadb",
):
    """Query a vector db, dispatching on `vector_database`.

    Raises:
        ValueError: if `vector_database` is not a supported backend.
    """
    if vector_database == "chromadb":
        return query_chromadb(
            query_texts, n_results, client, db_path, collection_name, search_string, embedding_model
        )
    elif vector_database == "lancedb":
        return query_lancedb(
            query_texts, n_results, db_path, collection_name, search_string, embedding_model
        )
    else:
        raise ValueError("Invalid vector_database. Please choose 'chromadb' or 'lancedb'.")