Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reformulate query if the query is duplicate #69

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions apps/slackbot/task_agent.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import json
from os import environ
from typing import List, Optional

from pydantic import ValidationError
import openai

from langchain.chains.llm import LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, Document, HumanMessage
from langchain.tools.base import BaseTool
from langchain.tools.human.tool import HumanInputRun
from langchain.vectorstores.base import VectorStoreRetriever
from pydantic import ValidationError

from output_parser import BaseTaskOutputParser, TaskOutputParser
from post_processors import md_link_to_slack
from prompt import SlackBotPrompt
from pydantic import ValidationError


class TaskAgent:
Expand Down Expand Up @@ -93,6 +91,7 @@ def run(self, task: str) -> str:

# Interaction Loop

previous_action = ""
while True:
# Discontinue if continuous limit is reached
loop_count = self.loop_count
Expand All @@ -117,7 +116,7 @@ def run(self, task: str) -> str:
user_input=user_input,
)
except openai.error.APIError as e:
return f"OpenAI API returned an API Error: {e}"
return f"OpenAI API returned an API Error: {e}"
except openai.error.APIConnectionError as e:
return f"Failed to connect to OpenAI API: {e}"
except openai.error.RateLimitError as e:
Expand All @@ -130,7 +129,6 @@ def run(self, task: str) -> str:
return f"OpenAI API Service unavailable: {e}"
except openai.error.InvalidRequestError as e:
return f"OpenAI API invalid request error: {e}"


assistant_reply = self.chain.run(
task=task,
Expand Down Expand Up @@ -177,6 +175,33 @@ def run(self, task: str) -> str:
action = self.output_parser.parse(assistant_reply)
print("action:", action)
tools = {t.name: t for t in self.tools}
if action == previous_action:
if action.name == "Search" or action.name == "Context Search":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend we drop this condition since we may want to check duplication for all tools. (We don't need to worry about the finish tool for now as it will be separated and it will also break the loop)

print(
"Action name: ", action.name, "\nStart reformulating the query"
)
instruction = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, this prompt only asks for a different input for the same tool. Is this intended?

f"You want to search for useful information to answer the query: {task}."
f"The original query is: {action.args['query']}"
f"Reformulate the query so that it can be used to search for relevant information."
f"Only return one query instead of multiple queries."
f"Reformulated query:\n\n"
)
openai.api_key = environ.get("OPENAI_KEY")
response = openai.Completion.create(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to use self.llm so that we don't need to provide extra configurations of the LLM here

engine="text-davinci-003",
prompt=" ".join(str(i) for i in self.previous_message)
+ "\n"
+ instruction,
temperature=0.7,
max_tokens=1024,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
)
reformulated_query = response["choices"][0]["text"]
action.args["query"] = reformulated_query

if action.name == "finish":
self.loop_count = self.max_iterations
result = "Finished task. "
Expand Down Expand Up @@ -220,6 +245,7 @@ def run(self, task: str) -> str:

# self.memory.add_documents([Document(page_content=memory_to_add)])
self.previous_message.append(HumanMessage(content=memory_to_add))
previous_action = action

def set_user_input(self, user_input: str):
result = f"Command UserInput returned: {user_input}"
Expand Down