Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix order of few shots sent to MessageBuilder #852

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions app/backend/approaches/chatreadretrieveread.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ def get_messages_from_history(
message_builder = MessageBuilder(system_prompt, model_id)

# Add examples to show the chat what responses we want. It will try to mimic any responses and make sure they match the rules laid out in the system message.
for shot in few_shots:
message_builder.append_message(shot.get("role"), shot.get("content"))
for shot in reversed(few_shots):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: do we need to reverse the order of the few shots? Does it make any difference?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that was actually the main point of this PR. The test will break without reversed(), because each few shot is inserted at index 1 and so they end up in the wrong order.

message_builder.insert_message(shot.get("role"), shot.get("content"))

append_index = len(few_shots) + 1

message_builder.append_message(self.USER, user_content, index=append_index)
message_builder.insert_message(self.USER, user_content, index=append_index)
total_token_count = message_builder.count_tokens_for_message(message_builder.messages[-1])

newest_to_oldest = list(reversed(history[:-1]))
Expand All @@ -349,7 +349,7 @@ def get_messages_from_history(
if (total_token_count + potential_message_count) > max_tokens:
logging.debug("Reached max tokens of %d, history will be truncated", max_tokens)
break
message_builder.append_message(message["role"], message["content"], index=append_index)
message_builder.insert_message(message["role"], message["content"], index=append_index)
total_token_count += potential_message_count
return message_builder.messages

Expand Down
6 changes: 3 additions & 3 deletions app/backend/approaches/retrievethenread.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,11 @@ async def run(

# add user question
user_content = q + "\n" + f"Sources:\n {content}"
message_builder.append_message("user", user_content)
message_builder.insert_message("user", user_content)

# Add shots/samples. This helps the model mimic the responses and make sure they match the rules laid out in the system message.
message_builder.append_message("assistant", self.answer)
message_builder.append_message("user", self.question)
message_builder.insert_message("assistant", self.answer)
message_builder.insert_message("user", self.question)

messages = message_builder.messages
chatgpt_args = {"deployment_id": self.chatgpt_deployment} if self.openai_host == "azure" else {}
Expand Down
12 changes: 10 additions & 2 deletions app/backend/core/messagebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,22 @@ class MessageBuilder:
token_count (int): The total number of tokens in the conversation.
Methods:
__init__(self, system_content: str, chatgpt_model: str): Initializes the MessageBuilder instance.
append_message(self, role: str, content: str, index: int = 1): Appends a new message to the conversation.
insert_message(self, role: str, content: str, index: int = 1): Inserts a new message into the conversation.
"""

def __init__(self, system_content: str, chatgpt_model: str):
self.messages = [{"role": "system", "content": self.normalize_content(system_content)}]
self.model = chatgpt_model

def append_message(self, role: str, content: str, index: int = 1):
def insert_message(self, role: str, content: str, index: int = 1) -> None:
    """
    Inserts a message into the conversation at the specified index,
    or at index 1 (after system message) if no index is specified.

    Because the default index is always 1, repeated calls that omit the index
    place each new message in front of the ones inserted before it. Callers
    that want to keep the original order should insert in reverse order or
    pass an explicit index.

    Args:
        role (str): The role of the message sender (e.g. "system", "user"
            or "assistant"). The value is stored as given; it is not
            validated here.
        content (str): The content of the message. It is passed through
            normalize_content() before being stored.
        index (int): The position in self.messages at which to insert the message.
    """
    message = {"role": role, "content": self.normalize_content(content)}
    self.messages.insert(index, message)

def count_tokens_for_message(self, message: dict[str, str]):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_chatapproach.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,27 @@ def test_extract_followup_questions_no_pre_content():
pre_content, followup_questions = chat_approach.extract_followup_questions(content)
assert pre_content == ""
assert followup_questions == ["What is the dress code?"]


def test_get_messages_from_history_few_shots():
    """Check the message order produced by get_messages_from_history with few shots.

    With an empty history, the first six messages must be the system prompt,
    then alternating user/assistant messages, and finally the user's query
    as the sixth message.
    """
    question = "What does a Product manager do?"
    approach = ChatReadRetrieveReadApproach(
        None, "", "gpt-35-turbo", "gpt-35-turbo", "", "", "", "", "en-us", "lexicon"
    )

    built = approach.get_messages_from_history(
        system_prompt=approach.query_prompt_template,
        model_id=approach.chatgpt_model,
        user_content=question,
        history=[],
        max_tokens=approach.chatgpt_token_limit - len(question),
        few_shots=approach.query_prompt_few_shots,
    )

    # Make sure messages are in the right order
    expected_roles = ["system", "user", "assistant", "user", "assistant", "user"]
    for position, expected_role in enumerate(expected_roles):
        assert built[position]["role"] == expected_role
    assert built[5]["content"] == question
4 changes: 2 additions & 2 deletions tests/test_messagebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_messagebuilder():

def test_messagebuilder_append():
builder = MessageBuilder("You are a bot.", "gpt-35-turbo")
builder.append_message("user", "Hello, how are you?")
builder.insert_message("user", "Hello, how are you?")
assert builder.messages == [
# 1 token, 1 token, 1 token, 5 tokens
{"role": "system", "content": "You are a bot."},
Expand All @@ -37,7 +37,7 @@ def test_messagebuilder_unicode():

def test_messagebuilder_unicode_append():
builder = MessageBuilder("a\u0301", "gpt-35-turbo")
builder.append_message("user", "a\u0301")
builder.insert_message("user", "a\u0301")
assert builder.messages == [
# 1 token, 1 token, 1 token, 1 token
{"role": "system", "content": "á"},
Expand Down