diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 62d9ffd93bb6..11193a91e042 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -125,7 +125,9 @@ def __init__( - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None. - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "". If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. - - no_update_context (Optional, bool): if True, will not apply `Update Context` for interactive retrieval. Default is False. + - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True. + - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat. + This is the same as that used in chromadb. Default is False. **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). """ super().__init__( @@ -148,7 +150,8 @@ def __init__( self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2") self.customized_prompt = self._retrieve_config.get("customized_prompt", None) self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper() - self.no_update_context = self._retrieve_config.get("no_update_context", False) + self.update_context = self._retrieve_config.get("update_context", True) + self._get_or_create = self._retrieve_config.get("get_or_create", False) self._context_max_tokens = self._max_tokens * 0.8 self._collection = False # the collection is not created self._ipython = get_ipython() @@ -231,7 +234,7 @@ def _generate_retrieve_user_reply( config: Optional[Any] = None, ) -> Tuple[bool, Union[str, Dict, None]]: """In this function, we will update the context and reset the conversation based on different conditions. - We'll update the context and reset the conversation if no_update_context is False and either of the following: + We'll update the context and reset the conversation if update_context is True and either of the following: (1) the last message contains "UPDATE CONTEXT", (2) the last message doesn't contain "UPDATE CONTEXT" and the customized_answer_prefix is not in the message. """ @@ -247,7 +250,7 @@ def _generate_retrieve_user_reply( update_context_case2 = ( self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper() ) - if (update_context_case1 or update_context_case2) and not self.no_update_context: + if (update_context_case1 or update_context_case2) and self.update_context: print(colored("Updating context and resetting conversation.", "green"), flush=True) # extract the first sentence in the response as the intermediate answer _message = message.get("content", "").split("\n")[0].strip() @@ -286,7 +289,7 @@ def _generate_retrieve_user_reply( return False, None def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): - if not self._collection: + if not self._collection or self._get_or_create: print("Trying to create collection.") create_vector_db_from_dir( dir_path=self._docs_path, @@ -296,8 +299,10 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = chunk_mode=self._chunk_mode, must_break_at_empty_line=self._must_break_at_empty_line, embedding_model=self._embedding_model, + get_or_create=self._get_or_create, ) self._collection = True + self._get_or_create = False results = query_vector_db( query_texts=[problem], diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index 5bb264612485..806834eb31ca 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -208,18 +208,13 @@ def create_vector_db_from_dir( chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line) print(f"Found {len(chunks)} chunks.") - # upsert in batch of 40000 - for i in range(0, len(chunks), 40000): + # Upsert in batch of 40000 or less if the total number of chunks is less than 40000 + for i in range(0, len(chunks), min(40000, len(chunks))): + end_idx = i + min(40000, len(chunks) - i) collection.upsert( - documents=chunks[ - i : i + 40000 - ], # we handle tokenization, embedding, and indexing automatically. You can skip that and add your own embeddings as well - ids=[f"doc_{i}" for i in range(i, i + 40000)], # unique for each doc + documents=chunks[i:end_idx], + ids=[f"doc_{j}" for j in range(i, end_idx)], # unique for each doc ) - collection.upsert( - documents=chunks[i : len(chunks)], - ids=[f"doc_{i}" for i in range(i, len(chunks))], # unique for each doc - ) except ValueError as e: logger.warning(f"{e}")