Improve RetrieveChat #6
Merged
Commits (11)

d8af774  Upsert in batch (thinkall)
8750c10  Improve update context, support customized answer prefix (thinkall)
ed15da5  Update tests (thinkall)
6bd0a0b  Update intermediate answer (thinkall)
d1d87b9  Fix duplicate intermediate answer, add example 6 to notebook (thinkall)
e288e30  Add notebook results (thinkall)
b9c82b6  Works better without intermediate answers in the context (thinkall)
d067816  Merge branch 'main' into improve_retrieve (thinkall)
db1c9b3  Bump version to 0.1.2 (thinkall)
596ff65  Remove commented code and add descriptions to _generate_retrieve_user… (thinkall)
199a38a  Merge branch 'main' into improve_retrieve (qingyun-wu)
Changes to RetrieveUserProxyAgent:

```diff
@@ -1,3 +1,4 @@
+import re
 import chromadb
 from autogen.agentchat.agent import Agent
 from autogen.agentchat import UserProxyAgent
```
@@ -122,6 +123,9 @@ def __init__( | |
| can be found at `https://www.sbert.net/docs/pretrained_models.html`. The default model is a | ||
| fast model. If you want to use a high performance model, `all-mpnet-base-v2` is recommended. | ||
| - customized_prompt (Optional, str): the customized prompt for the retrieve chat. Default is None. | ||
| - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "". | ||
| If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered. | ||
| - no_update_context (Optional, bool): if True, will not apply `Update Context` for interactive retrieval. Default is False. | ||
| **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__). | ||
| """ | ||
| super().__init__( | ||
|
|
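To make the new options concrete, here is a minimal sketch of how a caller might pass them through `retrieve_config`; the import path, `docs_path`, and the prompt wording are illustrative assumptions, not taken from this PR:

```python
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent  # assumed path

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    retrieve_config={
        "task": "qa",
        "docs_path": "./docs",  # hypothetical docs directory
        # The template must contain {input_question} and {input_context};
        # _generate_message fills it via str.format (see the diff below).
        "customized_prompt": "Answer based only on the context.\nQuestion: {input_question}\nContext: {input_context}",
        # Case-insensitive check: if this prefix never appears in the reply,
        # `Update Context` is triggered (update_context_case2 below).
        "customized_answer_prefix": "The answer is",
        "no_update_context": False,  # keep interactive retrieval enabled
    },
)
```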
@@ -143,11 +147,16 @@ def __init__( | |
| self._must_break_at_empty_line = self._retrieve_config.get("must_break_at_empty_line", True) | ||
| self._embedding_model = self._retrieve_config.get("embedding_model", "all-MiniLM-L6-v2") | ||
| self.customized_prompt = self._retrieve_config.get("customized_prompt", None) | ||
| self.customized_answer_prefix = self._retrieve_config.get("customized_answer_prefix", "").upper() | ||
| self.no_update_context = self._retrieve_config.get("no_update_context", False) | ||
| self._context_max_tokens = self._max_tokens * 0.8 | ||
| self._collection = False # the collection is not created | ||
| self._ipython = get_ipython() | ||
| self._doc_idx = -1 # the index of the current used doc | ||
| self._results = {} # the results of the current query | ||
| self._intermediate_answers = set() # the intermediate answers | ||
| self._doc_contents = [] # the contents of the current used doc | ||
| self._doc_ids = [] # the ids of the current used doc | ||
| self.register_reply(Agent, RetrieveUserProxyAgent._generate_retrieve_user_reply) | ||
|
|
||
| @staticmethod | ||
|
|
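For readers new to the `register_reply` call above: it hooks `_generate_retrieve_user_reply` into the agent's reply pipeline. A schematic of that contract, with the signature copied from the diff below and everything else illustrative:

```python
from typing import Any, Dict, List, Optional, Tuple, Union

from autogen.agentchat.agent import Agent

def my_reply_func(
    self,
    messages: Optional[List[Dict]] = None,
    sender: Optional[Agent] = None,
    config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
    # The bool says whether this reply is final; the second element is the
    # reply itself, or None to fall through to other reply functions.
    return False, None
```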
@@ -161,17 +170,24 @@ def get_max_tokens(model="gpt-3.5-turbo"): | |
| else: | ||
| return 4000 | ||
|
|
||
| def _reset(self): | ||
| def _reset(self, intermediate=False): | ||
| self._doc_idx = -1 # the index of the current used doc | ||
| self._results = {} # the results of the current query | ||
| if not intermediate: | ||
| self._intermediate_answers = set() # the intermediate answers | ||
| self._doc_contents = [] # the contents of the current used doc | ||
| self._doc_ids = [] # the ids of the current used doc | ||
|
|
||
| def _get_context(self, results): | ||
| doc_contents = "" | ||
| current_tokens = 0 | ||
| _doc_idx = self._doc_idx | ||
| _tmp_retrieve_count = 0 | ||
| for idx, doc in enumerate(results["documents"][0]): | ||
| if idx <= _doc_idx: | ||
| continue | ||
| if results["ids"][0][idx] in self._doc_ids: | ||
| continue | ||
| _doc_tokens = num_tokens_from_text(doc) | ||
| if _doc_tokens > self._context_max_tokens: | ||
| func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context." | ||
|
|
@@ -185,14 +201,19 @@ def _get_context(self, results): | |
| current_tokens += _doc_tokens | ||
| doc_contents += doc + "\n" | ||
| self._doc_idx = idx | ||
| self._doc_ids.append(results["ids"][0][idx]) | ||
| self._doc_contents.append(doc) | ||
| _tmp_retrieve_count += 1 | ||
| if _tmp_retrieve_count >= self.n_results: | ||
| break | ||
| return doc_contents | ||
|
|
||
| def _generate_message(self, doc_contents, task="default"): | ||
| if not doc_contents: | ||
| print(colored("No more context, will terminate.", "green"), flush=True) | ||
| return "TERMINATE" | ||
| if self.customized_prompt: | ||
| message = self.customized_prompt + "\nUser's question is: " + self.problem + "\nContext is: " + doc_contents | ||
| message = self.customized_prompt.format(input_question=self.problem, input_context=doc_contents) | ||
| elif task.upper() == "CODE": | ||
| message = PROMPT_CODE.format(input_question=self.problem, input_context=doc_contents) | ||
| elif task.upper() == "QA": | ||
|
|
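To see what the new bookkeeping buys, here is a standalone toy version of the selection loop, not the class itself: the chromadb-style `results` layout matches the diff, while the word-count tokenizer is a stand-in for `num_tokens_from_text`:

```python
def pick_new_docs(results, seen_ids, n_results, max_tokens=100):
    # Toy _get_context: skip docs already used and cap how many are added.
    picked, tokens = [], 0
    for idx, doc in enumerate(results["documents"][0]):
        doc_id = results["ids"][0][idx]
        if doc_id in seen_ids:  # mirrors the new `in self._doc_ids` guard
            continue
        doc_tokens = len(doc.split())  # stand-in for num_tokens_from_text
        if tokens + doc_tokens > max_tokens:
            break
        picked.append(doc_id)
        seen_ids.append(doc_id)
        tokens += doc_tokens
        if len(picked) >= n_results:  # mirrors the _tmp_retrieve_count cap
            break
    return picked

results = {"ids": [["d1", "d2", "d3"]], "documents": [["alpha beta", "gamma", "delta"]]}
print(pick_new_docs(results, seen_ids=["d2"], n_results=2))  # ['d1', 'd3']
```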
@@ -209,24 +230,64 @@ def _generate_retrieve_user_reply( | |
| sender: Optional[Agent] = None, | ||
| config: Optional[Any] = None, | ||
| ) -> Tuple[bool, Union[str, Dict, None]]: | ||
| """In this function, we will update the context and reset the conversation based on different conditions. | ||
| We'll update the context and reset the conversation if no_update_context is False and either of the following: | ||
|
A contributor commented on this docstring: It'll be easier to think of "update_context is True" than "no_update_context is False".
```diff
+        (1) the last message contains "UPDATE CONTEXT",
+        (2) the last message doesn't contain "UPDATE CONTEXT" and the customized_answer_prefix is not in the message.
         """
         if config is None:
             config = self
         if messages is None:
             messages = self._oai_messages[sender]
         message = messages[-1]
-        if (
+        update_context_case1 = (
             "UPDATE CONTEXT" in message.get("content", "")[-20:].upper()
             or "UPDATE CONTEXT" in message.get("content", "")[:20].upper()
-        ):
+        )
+        update_context_case2 = (
+            self.customized_answer_prefix and self.customized_answer_prefix not in message.get("content", "").upper()
+        )
+        if (update_context_case1 or update_context_case2) and not self.no_update_context:
             print(colored("Updating context and resetting conversation.", "green"), flush=True)
+            # extract the first sentence in the response as the intermediate answer
+            _message = message.get("content", "").split("\n")[0].strip()
+            _intermediate_info = re.split(r"(?<=[.!?])\s+", _message)
+            self._intermediate_answers.add(_intermediate_info[0])
+
+            if update_context_case1:
+                # try to get more context from the current retrieved doc results because the results may be too long to fit
+                # in the LLM context.
+                doc_contents = self._get_context(self._results)
+
+                # Always use self.problem as the query text to retrieve docs, but each time we replace the context with the
+                # next similar docs in the retrieved doc results.
+                if not doc_contents:
+                    for _tmp_retrieve_count in range(1, 5):
+                        self._reset(intermediate=True)
+                        self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
+                        doc_contents = self._get_context(self._results)
+                        if doc_contents:
+                            break
+            elif update_context_case2:
+                # Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar
+                # docs in the retrieved doc results to the context.
+                for _tmp_retrieve_count in range(5):
+                    self._reset(intermediate=True)
+                    self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
+                    self._get_context(self._results)
+                    doc_contents = "\n".join(self._doc_contents)  # + "\n" + "\n".join(self._intermediate_answers)
+                    if doc_contents:
+                        break
+
             self.clear_history()
             sender.clear_history()
-            doc_contents = self._get_context(self._results)
             return True, self._generate_message(doc_contents, task=self._task)
-        return False, None
+        else:
+            return False, None

     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
         if not self._collection:
             print("Trying to create collection.")
             create_vector_db_from_dir(
                 dir_path=self._docs_path,
                 max_tokens=self._chunk_token_size,
```
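Both update paths widen the retrieval window on each retry via `self.n_results * (2 * _tmp_retrieve_count + 1)`. A quick check of that schedule with the default `n_results=20`:

```python
n_results = 20  # default in retrieve_docs and generate_init_message

# case 1 retries with _tmp_retrieve_count in 1..4
print([n_results * (2 * k + 1) for k in range(1, 5)])  # [60, 100, 140, 180]

# case 2 retries with _tmp_retrieve_count in 0..4
print([n_results * (2 * k + 1) for k in range(5)])  # [20, 60, 100, 140, 180]
```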
@@ -263,6 +324,7 @@ def generate_init_message(self, problem: str, n_results: int = 20, search_string | |
| self._reset() | ||
| self.retrieve_docs(problem, n_results, search_string) | ||
| self.problem = problem | ||
| self.n_results = n_results | ||
| doc_contents = self._get_context(self._results) | ||
| message = self._generate_message(doc_contents, self._task) | ||
| return message | ||
|
|
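Storing `self.n_results` here is what lets `_get_context` cap each later `Update Context` round. End to end, the agent is driven roughly like this; the AssistantAgent setup and the problem text are assumptions for illustration, not part of this diff:

```python
from autogen import AssistantAgent

assistant = AssistantAgent(name="assistant")  # llm_config omitted for brevity
# initiate_chat forwards problem/n_results to generate_init_message, which
# retrieves docs and builds the first prompt from them.
ragproxyagent.initiate_chat(assistant, problem="What is AutoGen?", n_results=20)
```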
@@ -278,21 +340,6 @@ def run_code(self, code, **kwargs): | |
| if self._ipython is None or lang != "python": | ||
| return super().run_code(code, **kwargs) | ||
| else: | ||
| # # capture may not work as expected | ||
| # result = self._ipython.run_cell("%%capture --no-display cap\n" + code) | ||
| # log = self._ipython.ev("cap.stdout") | ||
| # log += self._ipython.ev("cap.stderr") | ||
| # if result.result is not None: | ||
| # log += str(result.result) | ||
| # exitcode = 0 if result.success else 1 | ||
| # if result.error_before_exec is not None: | ||
| # log += f"\n{result.error_before_exec}" | ||
| # exitcode = 1 | ||
| # if result.error_in_exec is not None: | ||
| # log += f"\n{result.error_in_exec}" | ||
| # exitcode = 1 | ||
| # return exitcode, log, None | ||
|
|
||
| result = self._ipython.run_cell(code) | ||
| log = str(result.result) | ||
| exitcode = 0 if result.success else 1 | ||
|
|
||
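The retained branch leans on the return value of IPython's `run_cell`. A minimal standalone demonstration of that API, outside the agent and purely illustrative:

```python
from IPython.core.interactiveshell import InteractiveShell

shell = InteractiveShell.instance()  # stand-in for get_ipython() in a live session
result = shell.run_cell("1 + 1")
exitcode = 0 if result.success else 1  # same mapping as run_code above
print(exitcode, result.result)  # 0 2
```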