-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add prompt size checking and use default llm
- Loading branch information
1 parent
0f34020
commit 7c154af
Showing
2 changed files
with
44 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,65 @@ | ||
import openai | ||
from os import environ | ||
from langchain.chat_models import ChatOpenAI | ||
from typing import List | ||
|
||
from langchain.base_language import BaseLanguageModel | ||
from langchain.schema import BaseMessage | ||
from langchain.tools import BaseTool | ||
from loguru import logger | ||
|
||
from prompt_generator import PromptGenerator | ||
|
||
class Reflection(): | ||
def __init__(self, tools, action_list = []): | ||
|
||
class Reflection: | ||
def __init__( | ||
self, llm: BaseLanguageModel, tools: List[BaseTool], action_list: List = [] | ||
): | ||
self.llm = llm | ||
self.action_list = action_list | ||
|
||
self.token_counter = self.llm.get_num_tokens | ||
prompt = PromptGenerator() | ||
self.format = prompt.response_format | ||
self.commands = [ | ||
f"{i + 1}. {prompt._generate_command_string(item)}" | ||
for i, item in enumerate(tools) | ||
] | ||
|
||
f"{i + 1}. {prompt._generate_command_string(item)}" | ||
for i, item in enumerate(tools) | ||
] | ||
|
||
def create_message_history( | ||
self, messages: List[BaseMessage], max_token=2000 | ||
) -> str: | ||
result = "" | ||
current_tokens = 0 | ||
|
||
for message in reversed(messages): | ||
current_tokens = current_tokens + self.token_counter(message.content) | ||
if current_tokens > max_token: | ||
break | ||
result = message.type + ": " + message.content + "\n" + result | ||
|
||
return result | ||
|
||
# we will also include previous_messages in the sherpa system | ||
def evaluate_action(self, action, assistant_reply, task, previous_message): | ||
self.action_list.append(action) | ||
if len(self.action_list) == 1: # first action, no previous action | ||
return assistant_reply | ||
else: | ||
previous_action = self.action_list[-2] | ||
if previous_action == action: # duplicate action | ||
message_history = self.create_message_history(previous_message) | ||
if previous_action == action: # duplicate action | ||
instruction = ( | ||
f"You want to solve the task: {task}." | ||
f"The original reply is: {assistant_reply}" | ||
f"Here is all the commands you can choose to use: {self.commands}" | ||
f"Here is previous messages: {previous_message}" | ||
f"Here is previous messages: \n{message_history}\n" | ||
f"We need a new reply by changing neither command.name or command.args.query." | ||
f"Make sure the new reply is different from the original reply by name or query." | ||
f"You should only respond in JSON format as described below without any extra text. Do not return the TaskAction object." | ||
f"Format for the new reply: {self.format}" | ||
f"Ensure the response can be parsed by Python json.loads" | ||
f"New reply:\n\n" | ||
) | ||
openai.api_key = environ.get("OPENAI_API_KEY") | ||
response = openai.Completion.create( | ||
engine="text-davinci-003", | ||
prompt= instruction, | ||
temperature=0.7, | ||
max_tokens=1024, | ||
top_p=1, | ||
frequency_penalty=0, | ||
presence_penalty=0 | ||
) | ||
return response['choices'][0]['text'] | ||
logger.warning(f"tokens used: {self.token_counter(instruction)}") | ||
return self.llm.predict(instruction) | ||
else: | ||
return assistant_reply | ||
return assistant_reply |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters