-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
reformulate query if the query is duplicate #69
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,18 @@ | ||
import json | ||
from os import environ | ||
from typing import List, Optional | ||
|
||
from pydantic import ValidationError | ||
import openai | ||
|
||
from langchain.chains.llm import LLMChain | ||
from langchain.chat_models.base import BaseChatModel | ||
from langchain.schema import AIMessage, BaseMessage, Document, HumanMessage | ||
from langchain.tools.base import BaseTool | ||
from langchain.tools.human.tool import HumanInputRun | ||
from langchain.vectorstores.base import VectorStoreRetriever | ||
from pydantic import ValidationError | ||
|
||
from output_parser import BaseTaskOutputParser, TaskOutputParser | ||
from post_processors import md_link_to_slack | ||
from prompt import SlackBotPrompt | ||
from pydantic import ValidationError | ||
|
||
|
||
class TaskAgent: | ||
|
@@ -93,6 +91,7 @@ def run(self, task: str) -> str: | |
|
||
# Interaction Loop | ||
|
||
previous_action = "" | ||
while True: | ||
# Discontinue if continuous limit is reached | ||
loop_count = self.loop_count | ||
|
@@ -117,7 +116,7 @@ def run(self, task: str) -> str: | |
user_input=user_input, | ||
) | ||
except openai.error.APIError as e: | ||
return f"OpenAI API returned an API Error: {e}" | ||
return f"OpenAI API returned an API Error: {e}" | ||
except openai.error.APIConnectionError as e: | ||
return f"Failed to connect to OpenAI API: {e}" | ||
except openai.error.RateLimitError as e: | ||
|
@@ -130,7 +129,6 @@ def run(self, task: str) -> str: | |
return f"OpenAI API Service unavailable: {e}" | ||
except openai.error.InvalidRequestError as e: | ||
return f"OpenAI API invalid request error: {e}" | ||
|
||
|
||
assistant_reply = self.chain.run( | ||
task=task, | ||
|
@@ -177,6 +175,33 @@ def run(self, task: str) -> str: | |
action = self.output_parser.parse(assistant_reply) | ||
print("action:", action) | ||
tools = {t.name: t for t in self.tools} | ||
if action == previous_action: | ||
if action.name == "Search" or action.name == "Context Search": | ||
print( | ||
"Action name: ", action.name, "\nStart reformulating the query" | ||
) | ||
instruction = ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I understand correctly, this prompt only asks for a different input for the same tool. Is this intended? |
||
f"You want to search for useful information to answer the query: {task}." | ||
f"The original query is: {action.args['query']}" | ||
f"Reformulate the query so that it can be used to search for relevant information." | ||
f"Only return one query instead of multiple queries." | ||
f"Reformulated query:\n\n" | ||
) | ||
openai.api_key = environ.get("OPENAI_KEY") | ||
response = openai.Completion.create( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may want to use |
||
engine="text-davinci-003", | ||
prompt=" ".join(str(i) for i in self.previous_message) | ||
+ "\n" | ||
+ instruction, | ||
temperature=0.7, | ||
max_tokens=1024, | ||
top_p=1, | ||
frequency_penalty=0, | ||
presence_penalty=0, | ||
) | ||
reformulated_query = response["choices"][0]["text"] | ||
action.args["query"] = reformulated_query | ||
|
||
if action.name == "finish": | ||
self.loop_count = self.max_iterations | ||
result = "Finished task. " | ||
|
@@ -220,6 +245,7 @@ def run(self, task: str) -> str: | |
|
||
# self.memory.add_documents([Document(page_content=memory_to_add)]) | ||
self.previous_message.append(HumanMessage(content=memory_to_add)) | ||
previous_action = action | ||
|
||
def set_user_input(self, user_input: str): | ||
result = f"Command UserInput returned: {user_input}" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I recommend we drop this condition since we may want to check duplication for all tools. (We don't need to worry about the
finish
tool for now as it will be separated and it will also break the loop)