diff --git a/app/backend/approaches/chatreadretrieveread.py b/app/backend/approaches/chatreadretrieveread.py
index 831219d8bf..dde45138c8 100644
--- a/app/backend/approaches/chatreadretrieveread.py
+++ b/app/backend/approaches/chatreadretrieveread.py
@@ -335,12 +335,12 @@ def get_messages_from_history(
         message_builder = MessageBuilder(system_prompt, model_id)
 
         # Add examples to show the chat what responses we want. It will try to mimic any responses and make sure they match the rules laid out in the system message.
-        for shot in few_shots:
-            message_builder.append_message(shot.get("role"), shot.get("content"))
+        for shot in reversed(few_shots):
+            message_builder.insert_message(shot.get("role"), shot.get("content"))
 
         append_index = len(few_shots) + 1
 
-        message_builder.append_message(self.USER, user_content, index=append_index)
+        message_builder.insert_message(self.USER, user_content, index=append_index)
         total_token_count = message_builder.count_tokens_for_message(message_builder.messages[-1])
 
         newest_to_oldest = list(reversed(history[:-1]))
@@ -349,7 +349,7 @@ def get_messages_from_history(
             if (total_token_count + potential_message_count) > max_tokens:
                 logging.debug("Reached max tokens of %d, history will be truncated", max_tokens)
                 break
-            message_builder.append_message(message["role"], message["content"], index=append_index)
+            message_builder.insert_message(message["role"], message["content"], index=append_index)
             total_token_count += potential_message_count
         return message_builder.messages
diff --git a/app/backend/approaches/retrievethenread.py b/app/backend/approaches/retrievethenread.py
index 861023dd8e..12b70dd493 100644
--- a/app/backend/approaches/retrievethenread.py
+++ b/app/backend/approaches/retrievethenread.py
@@ -127,11 +127,11 @@ async def run(
 
         # add user question
         user_content = q + "\n" + f"Sources:\n {content}"
-        message_builder.append_message("user", user_content)
+        message_builder.insert_message("user", user_content)
 
         # Add shots/samples. This helps model to mimic response and make sure they match rules laid out in system message.
-        message_builder.append_message("assistant", self.answer)
-        message_builder.append_message("user", self.question)
+        message_builder.insert_message("assistant", self.answer)
+        message_builder.insert_message("user", self.question)
 
         messages = message_builder.messages
         chatgpt_args = {"deployment_id": self.chatgpt_deployment} if self.openai_host == "azure" else {}
diff --git a/app/backend/core/messagebuilder.py b/app/backend/core/messagebuilder.py
index b071511a65..35d0edc595 100644
--- a/app/backend/core/messagebuilder.py
+++ b/app/backend/core/messagebuilder.py
@@ -12,14 +12,22 @@ class MessageBuilder:
         token_count (int): The total number of tokens in the conversation.
     Methods:
         __init__(self, system_content: str, chatgpt_model: str): Initializes the MessageBuilder instance.
-        append_message(self, role: str, content: str, index: int = 1): Appends a new message to the conversation.
+        insert_message(self, role: str, content: str, index: int = 1): Inserts a new message into the conversation.
     """
 
     def __init__(self, system_content: str, chatgpt_model: str):
         self.messages = [{"role": "system", "content": self.normalize_content(system_content)}]
         self.model = chatgpt_model
 
-    def append_message(self, role: str, content: str, index: int = 1):
+    def insert_message(self, role: str, content: str, index: int = 1):
+        """
+        Inserts a message into the conversation at the specified index,
+        or at index 1 (after the system message) if no index is specified.
+        Args:
+            role (str): The role of the message sender (either "user" or "system").
+            content (str): The content of the message.
+            index (int): The index at which to insert the message.
+        """
         self.messages.insert(index, {"role": role, "content": self.normalize_content(content)})
 
     def count_tokens_for_message(self, message: dict[str, str]):
diff --git a/tests/test_chatapproach.py b/tests/test_chatapproach.py
index d9b5889db3..22a7e3b387 100644
--- a/tests/test_chatapproach.py
+++ b/tests/test_chatapproach.py
@@ -207,3 +207,27 @@ def test_extract_followup_questions_no_pre_content():
     pre_content, followup_questions = chat_approach.extract_followup_questions(content)
     assert pre_content == ""
     assert followup_questions == ["What is the dress code?"]
+
+
+def test_get_messages_from_history_few_shots():
+    chat_approach = ChatReadRetrieveReadApproach(
+        None, "", "gpt-35-turbo", "gpt-35-turbo", "", "", "", "", "en-us", "lexicon"
+    )
+
+    user_query_request = "What does a Product manager do?"
+    messages = chat_approach.get_messages_from_history(
+        system_prompt=chat_approach.query_prompt_template,
+        model_id=chat_approach.chatgpt_model,
+        user_content=user_query_request,
+        history=[],
+        max_tokens=chat_approach.chatgpt_token_limit - len(user_query_request),
+        few_shots=chat_approach.query_prompt_few_shots,
+    )
+    # Make sure messages are in the right order
+    assert messages[0]["role"] == "system"
+    assert messages[1]["role"] == "user"
+    assert messages[2]["role"] == "assistant"
+    assert messages[3]["role"] == "user"
+    assert messages[4]["role"] == "assistant"
+    assert messages[5]["role"] == "user"
+    assert messages[5]["content"] == user_query_request
diff --git a/tests/test_messagebuilder.py b/tests/test_messagebuilder.py
index 3b6cb687fa..fbc09fa3ba 100644
--- a/tests/test_messagebuilder.py
+++ b/tests/test_messagebuilder.py
@@ -13,7 +13,7 @@ def test_messagebuilder():
 
 def test_messagebuilder_append():
     builder = MessageBuilder("You are a bot.", "gpt-35-turbo")
-    builder.append_message("user", "Hello, how are you?")
+    builder.insert_message("user", "Hello, how are you?")
     assert builder.messages == [
         # 1 token, 1 token, 1 token, 5 tokens
         {"role": "system", "content": "You are a bot."},
@@ -37,7 +37,7 @@ def test_messagebuilder_unicode():
 
 def test_messagebuilder_unicode_append():
     builder = MessageBuilder("a\u0301", "gpt-35-turbo")
-    builder.append_message("user", "a\u0301")
+    builder.insert_message("user", "a\u0301")
     assert builder.messages == [
         # 1 token, 1 token, 1 token, 1 token
        {"role": "system", "content": "á"},
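A note on the rename for reviewers: `insert_message` defaults to index 1, directly after the system message, which is why the few-shot loop in `get_messages_from_history` now iterates `reversed(few_shots)` to keep the examples in their original order, and why the live user question is inserted at `len(few_shots) + 1`. Below is a minimal standalone sketch of that ordering behavior; it is a simplified stand-in for the real `MessageBuilder` (content normalization and token counting omitted), with hypothetical example content.

```python
# Simplified stand-in for app/backend/core/messagebuilder.py to illustrate
# why reversed(few_shots) is needed once messages are inserted rather than appended.

few_shots = [
    {"role": "user", "content": "example question 1"},
    {"role": "assistant", "content": "example answer 1"},
    {"role": "user", "content": "example question 2"},
    {"role": "assistant", "content": "example answer 2"},
]

messages = [{"role": "system", "content": "system prompt"}]


def insert_message(role: str, content: str, index: int = 1) -> None:
    # The default index of 1 places a message directly after the system message.
    messages.insert(index, {"role": role, "content": content})


# Each insert at index 1 lands in front of earlier inserts, so the shots are
# walked in reverse to preserve their original order in the final list.
for shot in reversed(few_shots):
    insert_message(shot["role"], shot["content"])

# The live user question goes after all of the few shots.
append_index = len(few_shots) + 1
insert_message("user", "What does a Product manager do?", index=append_index)

# Same ordering as asserted in test_get_messages_from_history_few_shots.
assert [m["role"] for m in messages] == [
    "system", "user", "assistant", "user", "assistant", "user",
]
```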