Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self,
name="RetrieveChatAgent", # default set to RetrieveChatAgent
human_input_mode: Optional[str] = "ALWAYS",
vector_database: Optional[str] = "chromadb", # Add vector_database parameter
retrieve_config: Optional[Dict] = None, # config for the retrieve agent
**kwargs,
):
Expand Down Expand Up @@ -124,6 +125,7 @@ def __init__(
)

self._retrieve_config = {} if retrieve_config is None else retrieve_config
self._retrieve_config["vector_database"] = vector_database # Add vector_database to retrieve_config
self._task = self._retrieve_config.get("task", "default")
self._client = self._retrieve_config.get("client", chromadb.Client())
self._docs_path = self._retrieve_config.get("docs_path", "./docs")
Expand Down
103 changes: 78 additions & 25 deletions autogen/retrieve_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import chromadb.utils.embedding_functions as ef
import logging
import pypdf

from typing import List, Dict

logger = logging.getLogger(__name__)
TEXT_FORMATS = [
Expand Down Expand Up @@ -251,20 +251,12 @@ def is_url(string: str):
return all([result.scheme, result.netloc])
except ValueError:
return False



def create_vector_db_from_dir(
dir_path: str,
max_tokens: int = 4000,
client: API = None,
db_path: str = "/tmp/chromadb.db",
collection_name: str = "all-my-documents",
get_or_create: bool = False,
chunk_mode: str = "multi_lines",
must_break_at_empty_line: bool = True,
embedding_model: str = "all-MiniLM-L6-v2",
):
"""Create a vector db from all the files in a given directory."""
# Define separate functions for each vector database
def create_chromadb_from_dir(dir_path, max_tokens, client, db_path, collection_name, get_or_create, chunk_mode,
must_break_at_empty_line, embedding_model):
"""Create a ChromaDB from all the files in a given directory."""
if client is None:
client = chromadb.PersistentClient(path=db_path)
try:
Expand All @@ -279,6 +271,9 @@ def create_vector_db_from_dir(
metadata={"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}, # ip, l2, cosine
)

chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
logger.info(f"Found {len(chunks)} chunks.")

chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
logger.info(f"Found {len(chunks)} chunks.")
# Upsert in batch of 40000 or less if the total number of chunks is less than 40000
Expand All @@ -291,17 +286,8 @@ def create_vector_db_from_dir(
except ValueError as e:
logger.warning(f"{e}")


def query_vector_db(
query_texts: List[str],
n_results: int = 10,
client: API = None,
db_path: str = "/tmp/chromadb.db",
collection_name: str = "all-my-documents",
search_string: str = "",
embedding_model: str = "all-MiniLM-L6-v2",
) -> Dict[str, List[str]]:
"""Query a vector db."""
def query_chromadb(query_texts, n_results, client, db_path, collection_name, search_string, embedding_model):
"""Query a ChromaDB."""
if client is None:
client = chromadb.PersistentClient(path=db_path)
# the collection's embedding function is always the default one, but we want to use the one we used to create the
Expand All @@ -316,3 +302,70 @@ def query_vector_db(
where_document={"$contains": search_string} if search_string else None, # optional filter
)
return results

def create_lancedb_from_dir(dir_path, max_tokens, db_path, table_name, chunk_mode,
                            must_break_at_empty_line, embedding_model_name):
    """Create a LanceDB table from all the files in a given directory.

    Args:
        dir_path: path of the directory whose files are chunked and indexed.
        max_tokens: maximum number of tokens per chunk.
        db_path: path of the LanceDB database.
        table_name: name of the table the chunks are inserted into.
        chunk_mode: chunking strategy passed to `split_files_to_chunks`.
        must_break_at_empty_line: whether chunks must break at empty lines.
        embedding_model_name: sentence-transformers model used to embed chunks.
    """
    # NOTE(review): `LanceDB` is not imported in the visible part of this module
    # (the lancedb package itself exposes `lancedb.connect`) — confirm the import.
    db = LanceDB.connect(db_path)
    try:
        # Use the same chromadb sentence-transformers helper as the chromadb path.
        # Fix: the original passed the undefined name `embedding_model` (NameError);
        # the parameter is `embedding_model_name`.
        embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model_name)

        # NOTE(review): native lancedb uses `open_table` — confirm this wrapper API.
        table = db.get_table(table_name)

        chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
        logger.info(f"Found {len(chunks)} chunks.")  # use module logger, not print

        # Embed all chunks in one batch call instead of one model call per chunk
        # (reviewer request: batch to speed up insertion). The embedding function
        # is callable — chromadb's SentenceTransformerEmbeddingFunction has no
        # `generate_embeddings` method, so the original raised AttributeError.
        embeddings = embedding_function(chunks)
        for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
            document = {
                "vector": embedding,  # embedding vector for this chunk
                "text": chunk,
                "id": f"doc_{i}",
            }
            # NOTE(review): native lancedb tables insert via `add([...])` — confirm;
            # batching the inserts themselves would also speed this up.
            table.insert(document)
    except ValueError as e:
        logger.warning(f"{e}")

def query_lancedb(query_texts, n_results, db_path, table_name, search_string, embedding_model_name):
    """Query a LanceDB table for the chunks most similar to `query_texts`.

    Args:
        query_texts: list of query strings to embed and search with.
        n_results: number of most-similar results to return per query.
        db_path: path of the LanceDB database.
        table_name: name of the table to search.
        search_string: accepted for API symmetry with `query_chromadb`, but
            LanceDB has no document-contains filter here — currently unused.
        embedding_model_name: sentence-transformers model used to embed queries.

    Returns:
        The query results as returned by the LanceDB table.
    """
    # NOTE(review): `LanceDB` / `get_table` are not defined in the visible part of
    # this module (native lancedb uses `lancedb.connect` / `open_table`) — confirm.
    db = LanceDB.connect(db_path)
    table = db.get_table(table_name)

    # Fix: the original used the undefined name `SentenceTransformerEmbeddings`;
    # use the same chromadb embedding helper as the rest of this module.
    embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model_name)

    # The embedding function is callable on a list of texts; chromadb's helper
    # has no `generate_embeddings` method.
    query_embeddings = embedding_function(query_texts)

    # NOTE(review): native lancedb queries via `table.search(vec).limit(n)` —
    # confirm this wrapper's `query` signature.
    results = table.query(query_embeddings, n_results=n_results)

    return results

# Public API: dispatch to the selected vector-database backend.
def create_vector_db_from_dir(
    dir_path: str,
    max_tokens: int = 4000,
    client=None,
    db_path: str = "/tmp/chromadb.db",
    collection_name: str = "all-my-documents",
    get_or_create: bool = False,
    chunk_mode: str = "multi_lines",
    must_break_at_empty_line: bool = True,
    embedding_model: str = "all-MiniLM-L6-v2",
    vector_database: str = "chromadb",
):
    """Create a vector db from all the files in a given directory.

    Args:
        dir_path: path of the directory to index.
        max_tokens: maximum number of tokens per chunk.
        client: optional pre-created chromadb client (chromadb backend only).
        db_path: path of the database on disk.
        collection_name: collection to create/use (chromadb backend only).
        get_or_create: reuse an existing collection (chromadb backend only).
        chunk_mode: chunking strategy passed to `split_files_to_chunks`.
        must_break_at_empty_line: whether chunks must break at empty lines.
        embedding_model: sentence-transformers model used to embed chunks.
        vector_database: backend to use, "chromadb" or "lancedb".

    Raises:
        ValueError: if `vector_database` is not a supported backend.
    """
    # Fix: the original signature contained the syntax error
    # `max_tokens = max_tokens: int = 4000`; restored the typed default.
    if vector_database == "chromadb":
        create_chromadb_from_dir(dir_path, max_tokens, client, db_path, collection_name, get_or_create, chunk_mode,
                                 must_break_at_empty_line, embedding_model)
    elif vector_database == "lancedb":
        # NOTE(review): `client`, `collection_name`, and `get_or_create` are ignored
        # by the LanceDB backend; the table name is hard-coded to "all_documents".
        create_lancedb_from_dir(dir_path, max_tokens, db_path, "all_documents", chunk_mode, must_break_at_empty_line,
                                embedding_model)
    else:
        raise ValueError("Invalid vector_database. Please choose 'chromadb' or 'lancedb'.")

def query_vector_db(
    query_texts: List[str],
    n_results: int = 10,
    client=None,
    db_path: str = "/tmp/chromadb.db",
    collection_name: str = "all-my-documents",
    search_string: str = "",
    embedding_model: str = "all-MiniLM-L6-v2",
    vector_database: str = "chromadb",
):
    """Query a vector db, dispatching to the selected backend.

    Args:
        query_texts: list of query strings.
        n_results: number of most-similar results to return per query.
        client: optional pre-created chromadb client (chromadb backend only).
        db_path: path of the database on disk.
        collection_name: collection to search (chromadb backend only).
        search_string: optional document-contains filter (chromadb backend only).
        embedding_model: sentence-transformers model used to embed the queries.
        vector_database: backend to use, "chromadb" or "lancedb".

    Returns:
        The backend's query results.

    Raises:
        ValueError: if `vector_database` is not a supported backend.
    """
    # Fix: restored the original public default `n_results=10` — the refactor
    # dropped it, breaking existing callers that pass only `query_texts`.
    if vector_database == "chromadb":
        return query_chromadb(query_texts, n_results, client, db_path, collection_name, search_string, embedding_model)
    elif vector_database == "lancedb":
        # NOTE(review): `client` and `collection_name` are ignored by the LanceDB
        # backend; the table name is hard-coded to "all_documents".
        return query_lancedb(query_texts, n_results, db_path, "all_documents", search_string, embedding_model)
    else:
        raise ValueError("Invalid vector_database. Please choose 'chromadb' or 'lancedb'.")