From a57ae36aea008e7b8777096f5b192d70ec1800d9 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Mon, 9 Oct 2023 10:09:35 +0800 Subject: [PATCH 01/10] Add custom embedding function --- .../contrib/retrieve_user_proxy_agent.py | 6 +++ autogen/retrieve_utils.py | 52 +++++++++++++++++-- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 11193a91e042..3aed757b9acd 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -122,6 +122,9 @@ def __init__( If key not provided, a default model `all-MiniLM-L6-v2` will be used. All available models can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended. + - embedding_function (Optional, Callable): the embedding function for creating the vector db. Default is None, + SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or + other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None. - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "". If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. @@ -148,6 +151,7 @@ def __init__( self._chunk_mode = self._retrieve_config.get("chunk_mode", "multi_lines") self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True) self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2") + self._embedding_function = self._retrieve_config.get("embedding_function", None) self.customized_prompt = self._retrieve_config.get("customized_prompt", None) self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper() self.update_context = self._retrieve_config.get("update_context", True) @@ -300,6 +304,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = must_break_at_empty_line=self._must_break_at_empty_line, embedding_model=self._embedding_model, get_or_create=self._get_or_create, + embedding_function=self._embedding_function, ) self._collection = True self._get_or_create = False @@ -311,6 +316,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = client=self._client, collection_name=self._collection_name, embedding_model=self._embedding_model, + embedding_function=self._embedding_function, ) self._results = results print("doc_ids: ", results["ids"]) diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index ecca2f2b0bfe..b9677e8c079f 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -1,4 +1,4 @@ -from typing import List, Union, Dict, Tuple +from typing import List, Union, Dict, Tuple, Callable import os import requests from urllib.parse import urlparse @@ -246,12 +246,36 @@ def create_vector_db_from_dir( chunk_mode: str = "multi_lines", must_break_at_empty_line: bool = True, embedding_model: str = "all-MiniLM-L6-v2", + embedding_function: Callable = None, ): - """Create a vector db from all the files in a given directory.""" + """Create a vector db from all the files in a given directory, the directory can also be a single file or a url to + a single file. We support chromadb compatible APIs to create the vector db, this function is not required if + you prepared your own vector db. + + Args: + dir_path (str): the path to the directory, file or url. + max_tokens (Optional, int): the maximum number of tokens per chunk. Default is 4000. + client (Optional, API): the chromadb client. Default is None. + db_path (Optional, str): the path to the chromadb. Default is "/tmp/chromadb.db". + collection_name (Optional, str): the name of the collection. Default is "all-my-documents". + get_or_create (Optional, bool): Whether to get or create the collection. Default is False. If True, the collection + will be recreated if it already exists. + chunk_mode (Optional, str): the chunk mode. Default is "multi_lines". + must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True. + embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if + embedding_function is not None. + embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with + the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding + functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. + """ if client is None: client = chromadb.PersistentClient(path=db_path) try: - embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model) + embedding_function = ( + ef.SentenceTransformerEmbeddingFunction(embedding_model) + if embedding_function is None + else embedding_function + ) collection = client.create_collection( collection_name, get_or_create=get_or_create, @@ -283,14 +307,32 @@ def query_vector_db( collection_name: str = "all-my-documents", search_string: str = "", embedding_model: str = "all-MiniLM-L6-v2", + embedding_function: Callable = None, ) -> Dict[str, List[str]]: - """Query a vector db.""" + """Query a vector db. We support chromadb compatible APIs, it's not required if you prepared your own vector db + and query function. + + Args: + query_texts (List[str]): the query texts. + n_results (Optional, int): the number of results to return. Default is 10. + client (Optional, API): the chromadb compatible client. Default is None, a chromadb client will be used. + db_path (Optional, str): the path to the vector db. Default is "/tmp/chromadb.db". + collection_name (Optional, str): the name of the collection. Default is "all-my-documents". + search_string (Optional, str): the search string. Default is "". + embedding_model (Optional, str): the embedding model to use. Default is "all-MiniLM-L6-v2". Will be ignored if + embedding_function is not None. + embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with + the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding + functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. + """ if client is None: client = chromadb.PersistentClient(path=db_path) # the collection's embedding function is always the default one, but we want to use the one we used to create the # collection. So we compute the embeddings ourselves and pass it to the query function. collection = client.get_collection(collection_name) - embedding_function = ef.SentenceTransformerEmbeddingFunction(embedding_model) + embedding_function = ( + ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function + ) query_embeddings = embedding_function(query_texts) # Query/search n most similar results. You can also .get by id results = collection.query( From e57f2b9c1783d9060937b8ea14e2b08daecb7595 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Mon, 9 Oct 2023 11:13:32 +0800 Subject: [PATCH 02/10] Add support to custom vector db --- .../contrib/retrieve_user_proxy_agent.py | 30 +++++++++++++++---- autogen/retrieve_utils.py | 12 +++++++- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 3aed757b9acd..71a6f97dbd51 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -105,7 +105,7 @@ def __init__( - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()` will be used. - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file, - or the url to a single file. If key not provided, a default path `./docs` will be used. + or the url to a single file. Default is None, which works only if the collection is already created. - collection_name (Optional, str): the name of the collection. If key not provided, a default name `autogen-docs` will be used. - model (Optional, str): the model to use for the retrieve chat. @@ -130,7 +130,7 @@ def __init__( If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True. - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat. - This is the same as that used in chromadb. Default is False. + This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None. **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). """ super().__init__( @@ -143,7 +143,7 @@ def __init__( self._retrieve_config = {} if retrieve_config is None else retrieve_config self._task = self._retrieve_config.get("task", "default") self._client = self._retrieve_config.get("client", chromadb.Client()) - self._docs_path = self._retrieve_config.get("docs_path", "./docs") + self._docs_path = self._retrieve_config.get("docs_path", None) self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs") self._model = self._retrieve_config.get("model", "gpt-4") self._max_tokens = self.get_max_tokens(self._model) @@ -155,9 +155,11 @@ def __init__( self.customized_prompt = self._retrieve_config.get("customized_prompt", None) self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper() self.update_context = self._retrieve_config.get("update_context", True) - self._get_or_create = self._retrieve_config.get("get_or_create", False) + self._get_or_create = ( + self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False + ) self._context_max_tokens = self._max_tokens * 0.8 - self._collection = False # the collection is not created + self._collection = True if self._docs_path is None else False # whether the collection is created self._ipython = get_ipython() self._doc_idx = -1 # the index of the current used doc self._results = {} # the results of the current query @@ -185,7 +187,7 @@ def _reset(self, intermediate=False): self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc - def _get_context(self, results): + def _get_context(self, results: Dict[str, Union[List[str], List[List[str]]]]): doc_contents = "" current_tokens = 0 _doc_idx = self._doc_idx @@ -293,6 +295,22 @@ def _generate_retrieve_user_reply( return False, None def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): + """Retrieve docs based on the given problem and assign the results to the class property `_results`. + In case you want to customize the retrieval process, such as using a different vector db whose APIs are not + compatible with chromadb or filter results with metadata, you can override this function. Just keep the current + parameters and add your own parameters with default values, and keep the results in below type. + + Type of the results: Dict[str, List[List[Any]]] + ids: List[string] + documents: List[List[string]] + metadatas: Optional[List[List[string]]] + distances: Optional[List[List[float]]] + + Args: + problem (str): the problem to be solved. + n_results (int): the number of results to be retrieved. + search_string (str): only docs containing this string will be retrieved. + """ if not self._collection or self._get_or_create: print("Trying to create collection.") create_vector_db_from_dir( diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index b9677e8c079f..85a4c0dd024f 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -6,6 +6,7 @@ import tiktoken import chromadb from chromadb.api import API +from chromadb.api.types import QueryResult import chromadb.utils.embedding_functions as ef import logging import pypdf @@ -308,7 +309,7 @@ def query_vector_db( search_string: str = "", embedding_model: str = "all-MiniLM-L6-v2", embedding_function: Callable = None, -) -> Dict[str, List[str]]: +) -> QueryResult: """Query a vector db. We support chromadb compatible APIs, it's not required if you prepared your own vector db and query function. @@ -324,6 +325,15 @@ def query_vector_db( embedding_function (Optional, Callable): the embedding function to use. Default is None, SentenceTransformer with the given `embedding_model` will be used. If you want to use OpenAI, Cohere, HuggingFace or other embedding functions, you can pass it here, follow the examples in `https://docs.trychroma.com/embeddings`. + + Returns: + QueryResult: the query result. The format is: + class QueryResult(TypedDict): + ids: List[IDs] + embeddings: Optional[List[List[Embedding]]] + documents: Optional[List[List[Document]]] + metadatas: Optional[List[List[Metadata]]] + distances: Optional[List[List[float]]] """ if client is None: client = chromadb.PersistentClient(path=db_path) From 31f364d9947fbfdc2dd4537cfc63ef6d5b1e4204 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Mon, 9 Oct 2023 12:04:35 +0800 Subject: [PATCH 03/10] Improve docstring --- autogen/agentchat/contrib/retrieve_user_proxy_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 51f196482566..a865216b80de 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -86,8 +86,8 @@ def __init__( To use default config, set to None. Otherwise, set to a dictionary with the following keys: - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System prompt will be different for different tasks. The default value is `default`, which supports both code and qa. - - client (Optional, chromadb.Client): the chromadb client. - If key not provided, a default client `chromadb.Client()` will be used. + - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()` + will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function. - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file, or the url to a single file. Default is None, which works only if the collection is already created. - collection_name (Optional, str): the name of the collection. From 73a4fb4522c69694731ad1626e7c655947d4f66c Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Mon, 9 Oct 2023 12:21:31 +0800 Subject: [PATCH 04/10] Improve docstring --- .../contrib/retrieve_user_proxy_agent.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index a865216b80de..c2bdaf9e5003 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -318,6 +318,33 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = problem (str): the problem to be solved. n_results (int): the number of results to be retrieved. search_string (str): only docs containing this string will be retrieved. + + Example of overriding this function: + If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with + below code. + ```python + class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): + def query_vector_db( + self, + query_texts: List[str], + n_results: int = 10, + search_string: str = "", + **kwargs, + ) -> Dict[str, Union[List[str], List[List[str]]]]: + # define your own query function here + pass + + def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs): + results = self.query_vector_db( + query_texts=[problem], + n_results=n_results, + search_string=search_string, + **kwargs, + ) + + self._results = results + print("doc_ids: ", results["ids"]) + ``` """ if not self._collection or self._get_or_create: print("Trying to create collection.") From 73f0878ab2d6cf01415de2c95f22ed04ba601434 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Mon, 9 Oct 2023 12:26:29 +0800 Subject: [PATCH 05/10] Improve docstring --- .../contrib/retrieve_user_proxy_agent.py | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index c2bdaf9e5003..64b192ccd5c5 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -119,6 +119,32 @@ def __init__( The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name). Default is None, tiktoken will be used and may not be accurate for non-OpenAI models. **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). + + Example of overriding retrieve_docs: + If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with below code. + ```python + class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): + def query_vector_db( + self, + query_texts: List[str], + n_results: int = 10, + search_string: str = "", + **kwargs, + ) -> Dict[str, Union[List[str], List[List[str]]]]: + # define your own query function here + pass + + def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs): + results = self.query_vector_db( + query_texts=[problem], + n_results=n_results, + search_string=search_string, + **kwargs, + ) + + self._results = results + print("doc_ids: ", results["ids"]) + ``` """ super().__init__( name=name, @@ -318,33 +344,6 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = problem (str): the problem to be solved. n_results (int): the number of results to be retrieved. search_string (str): only docs containing this string will be retrieved. - - Example of overriding this function: - If you have set up a customized vector db, and it's not compatible with chromadb, you can easily plug in it with - below code. - ```python - class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): - def query_vector_db( - self, - query_texts: List[str], - n_results: int = 10, - search_string: str = "", - **kwargs, - ) -> Dict[str, Union[List[str], List[List[str]]]]: - # define your own query function here - pass - - def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = "", **kwargs): - results = self.query_vector_db( - query_texts=[problem], - n_results=n_results, - search_string=search_string, - **kwargs, - ) - - self._results = results - print("doc_ids: ", results["ids"]) - ``` """ if not self._collection or self._get_or_create: print("Trying to create collection.") From b7d8863047cab5fad5929aed56c87aa9b1867c70 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 10 Oct 2023 01:23:07 +0000 Subject: [PATCH 06/10] Add support to customized is_termination_msg fucntion --- autogen/agentchat/contrib/retrieve_user_proxy_agent.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 64b192ccd5c5..73fe5160c292 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -67,6 +67,7 @@ def __init__( self, name="RetrieveChatAgent", # default set to RetrieveChatAgent human_input_mode: Optional[str] = "ALWAYS", + is_termination_msg: Optional[Callable[[Dict], bool]] = None, retrieve_config: Optional[Dict] = None, # config for the retrieve agent **kwargs, ): @@ -82,6 +83,9 @@ def __init__( the number of auto reply reaches the max_consecutive_auto_reply. (3) When "NEVER", the agent will never prompt for human input. Under this mode, the conversation stops when the number of auto reply reaches the max_consecutive_auto_reply or when is_termination_msg is True. + is_termination_msg (function): a function that takes a message in the form of a dictionary + and returns a boolean value indicating if this received message is a termination message. + The dict can contain the following keys: "content", "role", "name", "function_call". retrieve_config (dict or None): config for the retrieve agent. To use default config, set to None. Otherwise, set to a dictionary with the following keys: - task (Optional, str): the task of the retrieve chat. Possible values are "code", "qa" and "default". System @@ -179,7 +183,10 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = self._intermediate_answers = set() # the intermediate answers self._doc_contents = [] # the contents of the current used doc self._doc_ids = [] # the ids of the current used doc - self._is_termination_msg = self._is_termination_msg_retrievechat # update the termination message function + # update the termination message function + self._is_termination_msg = ( + self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg + ) self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply, position=1) def _is_termination_msg_retrievechat(self, message): From 8cb1bcddaa7a2cea86073275b19451290e4a6e9f Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 10 Oct 2023 02:04:35 +0000 Subject: [PATCH 07/10] Add a test for customize vector db with lancedb --- setup.py | 1 + test/test_retrieve_utils.py | 61 +++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/setup.py b/setup.py index 37c9d2d883fd..a42432eb0333 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ extras_require={ "test": [ "chromadb", + "lancedb", "coverage>=5.3", "datasets", "ipykernel", diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index fdb93d26ca8d..fb162d45ee37 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -100,6 +100,67 @@ def test_query_vector_db(self): results = query_vector_db(["autogen"], client=client) assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) + def test_custom_vector_db(self): + import lancedb + from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent + + db_path = "/tmp/lancedb" + + def create_lancedb(): + db = lancedb.connect(db_path) + data = [ + {"vector": [1.1, 1.2], "id": 1, "documents": "This is a test document spark"}, + {"vector": [0.2, 1.8], "id": 2, "documents": "This is another test document"}, + {"vector": [0.1, 0.3], "id": 3, "documents": "This is a third test document spark"}, + {"vector": [0.5, 0.7], "id": 4, "documents": "This is a fourth test document"}, + {"vector": [2.1, 1.3], "id": 5, "documents": "This is a fifth test document spark"}, + {"vector": [5.1, 8.3], "id": 6, "documents": "This is a sixth test document"}, + ] + try: + db.create_table("my_table", data) + except OSError: + pass + + class MyRetrieveUserProxyAgent(RetrieveUserProxyAgent): + def query_vector_db( + self, + query_texts, + n_results=10, + search_string="", + ): + if query_texts: + vector = [0.1, 0.3] + db = lancedb.connect(db_path) + table = db.open_table("my_table") + query = table.search(vector).where(f"documents LIKE '%{search_string}%'").limit(n_results).to_df() + return {"ids": query["id"].tolist(), "documents": query["documents"].tolist()} + + def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): + results = self.query_vector_db( + query_texts=[problem], + n_results=n_results, + search_string=search_string, + ) + + self._results = results + print("doc_ids: ", results["ids"]) + + ragragproxyagent = MyRetrieveUserProxyAgent( + name="ragproxyagent", + human_input_mode="NEVER", + max_consecutive_auto_reply=2, + retrieve_config={ + "task": "qa", + "chunk_token_size": 2000, + "client": "__", + "embedding_model": "all-mpnet-base-v2", + }, + ) + + create_lancedb() + ragragproxyagent.retrieve_docs("This is a test document spark", n_results=10, search_string="spark") + assert ragragproxyagent._results["ids"] == [3, 1, 5] + if __name__ == "__main__": pytest.main() From 2c24eb5fdaa598a4844889ffece6d00c264ef8c0 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 10 Oct 2023 02:11:26 +0000 Subject: [PATCH 08/10] Fix tests --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 63ca0a254609..5e5fd186beac 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -40,7 +40,7 @@ jobs: python -m pip install --upgrade pip wheel pip install -e . python -c "import autogen" - pip install -e.[mathchat,retrievechat] datasets pytest + pip install -e.[mathchat,retrievechat,test] datasets pytest pip uninstall -y openai - name: Test with pytest if: matrix.python-version != '3.10' From 3fe2586d66418d5c40852acfe4fd6701b7e08201 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 10 Oct 2023 02:50:47 +0000 Subject: [PATCH 09/10] Add test for embedding_function --- test/agentchat/test_retrievechat.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/agentchat/test_retrievechat.py b/test/agentchat/test_retrievechat.py index bde5730cbbb2..99e395de5056 100644 --- a/test/agentchat/test_retrievechat.py +++ b/test/agentchat/test_retrievechat.py @@ -12,6 +12,7 @@ ) from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db import chromadb + from chromadb.utils import embedding_functions as ef skip_test = False except ImportError: @@ -49,6 +50,7 @@ def test_retrievechat(): }, ) + sentence_transformer_ef = ef.SentenceTransformerEmbeddingFunction() ragproxyagent = RetrieveUserProxyAgent( name="ragproxyagent", human_input_mode="NEVER", @@ -58,6 +60,7 @@ def test_retrievechat(): "chunk_token_size": 2000, "model": config_list[0]["model"], "client": chromadb.PersistentClient(path="/tmp/chromadb"), + "embedding_function": sentence_transformer_ef, }, ) From 95ec946be33b63b767d491d500cb24699fe6c992 Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Tue, 10 Oct 2023 06:53:47 +0000 Subject: [PATCH 10/10] Update docstring --- autogen/agentchat/contrib/retrieve_user_proxy_agent.py | 6 +++--- test/test_retrieve_utils.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 73fe5160c292..0f29aa62d14f 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -341,11 +341,11 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = compatible with chromadb or filter results with metadata, you can override this function. Just keep the current parameters and add your own parameters with default values, and keep the results in below type. - Type of the results: Dict[str, List[List[Any]]] + Type of the results: Dict[str, List[List[Any]]], should have keys "ids" and "documents", "ids" for the ids of + the retrieved docs and "documents" for the contents of the retrieved docs. Any other keys are optional. Refer + to `chromadb.api.types.QueryResult` as an example. ids: List[string] documents: List[List[string]] - metadatas: Optional[List[List[string]]] - distances: Optional[List[List[float]]] Args: problem (str): the problem to be solved. diff --git a/test/test_retrieve_utils.py b/test/test_retrieve_utils.py index fb162d45ee37..be215facb846 100644 --- a/test/test_retrieve_utils.py +++ b/test/test_retrieve_utils.py @@ -101,7 +101,10 @@ def test_query_vector_db(self): assert isinstance(results, dict) and any("autogen" in res[0].lower() for res in results.get("documents", [])) def test_custom_vector_db(self): - import lancedb + try: + import lancedb + except ImportError: + return from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent db_path = "/tmp/lancedb"