Add message passing format
aymeric-roucher committed Apr 19, 2024
1 parent 881614b commit e903529
Showing 3 changed files with 138 additions and 101 deletions.
1 change: 1 addition & 0 deletions src/transformers/tools/__init__.py
@@ -67,6 +67,7 @@
     from .text_to_speech import TextToSpeechTool
     from .translation import TranslationTool
     from .default_tools import CalculatorTool, PythonEvaluatorTool
+    from .agents import ReactAgent, CodeAgent
 else:
     import sys

230 changes: 133 additions & 97 deletions src/transformers/tools/agents.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 import importlib.util
 import re
+import json
 from enum import Enum
 from ast import literal_eval
 from dataclasses import dataclass
@@ -65,9 +66,7 @@ class PreTool:

 HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
     "image-transformation",
-    # "text-download",
     "text-to-image",
-    # "text-to-video",
 ]

@@ -82,10 +81,10 @@ def __call__(self):

 class MessageRole(str, Enum):
     USER = "user"
-    ASSITANT = "assistant"
+    ASSISTANT = "assistant"
     SYSTEM = "system"
-    FUNCTION_CALL = "function-call"
-    FUNCTION_RESPONSE = "function-response"
+    TOOL_CALL = "tool-call"
+    TOOL_RESPONSE = "tool-response"

     @classmethod
     def roles(cls):
@@ -163,15 +162,16 @@ def clean_code_for_run(code):
     code_lines = code_lines[:-1]
     code = "\n".join(code_lines)
     return code
+import json


 def parse_json_blob(json_blob: str):
     try:
         first_accolade_index = json_blob.find("{")
         last_accolade_index = [a.start() for a in list(re.finditer('}', json_blob))][-1]
-        json_blob = json_blob[first_accolade_index:last_accolade_index+1]
-        return literal_eval(json_blob)
+        json_blob = json_blob[first_accolade_index:last_accolade_index+1].replace("\\", "")
+        return json.loads(json_blob)
     except Exception as e:
-        raise ValueError(f"The JSON blob you used is invalid: due to the following error: {e}. Try to correct its formatting.")
+        raise ValueError(f"The JSON blob you used is invalid: due to the following error: {e}. Make sure to correct its formatting.")


 def parse_json_tool_call(json_blob: str):
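Reviewer sketch, not part of the diff: the switch in parse_json_blob above from literal_eval to json.loads means standard JSON tokens like true and null now parse, while the new blanket .replace("\\", "") tolerates over-escaped model output at the price of corrupting legitimate escape sequences. A minimal, self-contained illustration with a hypothetical action blob:

import json
import re

llm_output = 'Action:\n{"action": "calculator", "action_input": {"expression": "2 + 2"}}'
first_accolade_index = llm_output.find("{")
last_accolade_index = [a.start() for a in re.finditer("}", llm_output)][-1]
blob = llm_output[first_accolade_index : last_accolade_index + 1].replace("\\", "")
print(json.loads(blob))  # {'action': 'calculator', 'action_input': {'expression': '2 + 2'}}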
@@ -295,20 +295,54 @@ class AgentMaxIterationsError(AgentError):
     pass

-def get_inner_memory_from_logs(logs: List[Dict[str, Union[str, AgentError]]]) -> str:
-    """
-    Reads past llm_outputs, actions, and observations or errors from the logs.
-    """
-    memory = logs[0]["system_prompt"] + "\n" + logs[0]["task"]
-    for step_log in logs[1:]:
-        memory += "\nThought: " + step_log["llm_output"] + "\n"
-        if 'error' in step_log:
-            memory += str(step_log["error"]) + "\nNow let's retry: take care not to repeat previous errors! Try to adopt different approaches if you can.\n"
-        else:
-            memory += "Observation: " + step_log["observation"]
-    return memory
+def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}):
+    """
+    Subsequent messages with the same role will be concatenated to a single message.
+
+    Args:
+        message_list (`List[Dict[str, str]]`): List of chat messages.
+    """
+    final_message_list = []
+    for message in message_list:
+        if not set(message.keys()) == {"role", "content"}:
+            raise ValueError("Message should contain only 'role' and 'content' keys!")
+
+        role = message["role"]
+        if role not in MessageRole.roles():
+            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
+
+        if role in role_conversions:
+            message["role"] = role_conversions[role]
+
+        if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
+            final_message_list[-1]["content"] += "\n" + message["content"]
+        else:
+            final_message_list.append(message)
+    return final_message_list
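Reviewer sketch, not part of the diff: this helper takes over role validation and same-role merging from the removed Agent.add_message further down, and drops a subtle bug there — add_message's walrus condition (if role := message["role"] not in MessageRole.roles():) bound the boolean result of the membership test to role, so its error message would complain about role "True". Illustration of conversion plus merging, using the llama_role_conversions mapping added just below:

messages = [
    {"role": MessageRole.SYSTEM, "content": "You are an agent."},
    {"role": MessageRole.USER, "content": "Task: add 2 and 2"},
    {"role": MessageRole.TOOL_RESPONSE, "content": "Observation: 4"},
]
print(get_clean_message_list(messages, role_conversions=llama_role_conversions))
# All three roles map to "user", so the messages collapse into a single one:
# [{'role': <MessageRole.USER: 'user'>, 'content': 'You are an agent.\nTask: add 2 and 2\nObservation: 4'}]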


+llama_role_conversions = {
+    MessageRole.SYSTEM: MessageRole.USER,
+    MessageRole.TOOL_RESPONSE: MessageRole.USER,
+}

+class LLMEngine:
+    def __init__(self, client):
+        self.client = client
+
+    def call(self, messages: List[Dict[str, str]], stop=["Output:"]) -> str:
+        # Get clean message list
+        messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
+
+        # Get answer
+        response = self.client.chat_completion(messages, stop=stop, max_tokens=1500)
+        response = response.choices[0].message.content
+
+        # Remove stop sequences from the answer
+        for stop_seq in stop:
+            if response[-len(stop_seq) :] == stop_seq:
+                response = response[: -len(stop_seq)]
+        return response
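Reviewer sketch, not part of the diff: the commit does not pin down what client is, but the chat_completion(...).choices[0].message.content access pattern matches huggingface_hub's InferenceClient, so a plausible wiring looks like this (model choice hypothetical):

from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")  # hypothetical model choice
engine = LLMEngine(client)

messages = [
    {"role": "system", "content": "You are a helpful agent."},
    {"role": "user", "content": "Task: say hello"},
]
print(engine.call(messages, stop=["Task:"]))

Since the agents invoke self.llm_engine(...) directly and the class defines call rather than __call__, presumably the bound method engine.call is what gets passed around as llm_engine.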


class Agent:
@@ -322,6 +356,7 @@ def __init__(
         max_iterations=1,
         tool_parser=parse_json_tool_call,
         add_base_tools: bool = False,
+        verbose=False,
     ):

         self.agent_name = self.__class__.__name__
@@ -335,23 +370,59 @@ def __init__(

         self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)

-        self.system_message = {
-            "role": MessageRole.SYSTEM,
-            "content": format_prompt(self._toolbox, self.prompt_template, self.tool_description_template)
-        }
-        self.messages = []
+        self.system_prompt = format_prompt(self._toolbox, self.system_prompt_template, self.tool_description_template)
         self.prompt = None
         self.logs = []

+        if verbose:
+            logging.set_verbosity_debug()

     @property
     def toolbox(self) -> Dict[str, Tool]:
         """Get the toolbox currently available to the agent"""
         return self._toolbox

+    def get_inner_memory_from_logs(self) -> str:
+        """
+        Reads past llm_outputs, actions, and observations or errors from the logs.
+        """
+        prompt_message = {
+            "role": MessageRole.SYSTEM,
+            "content": self.logs[0]["system_prompt"]
+        }
+        task_message = {
+            "role": MessageRole.USER,
+            "content": "Task: " + self.logs[0]["task"],
+        }
+        memory = [prompt_message, task_message]
+
+        for step_log in self.logs[1:]:
+            thought_message = {
+                "role": MessageRole.ASSISTANT,
+                "content": "Thought: " + step_log["llm_output"] + "\n"
+            }
+            memory.append(thought_message)
+
+            if 'error' in step_log:
+                message_content = "Error: " + str(step_log["error"]) + "\nNow let's retry: take care not to repeat previous errors! Try to adopt different approaches if you can.\n"
+            else:
+                message_content = f"Observation: {step_log['observation']}"
+            tool_response_message = {
+                "role": MessageRole.TOOL_RESPONSE,
+                "content": message_content
+            }
+            memory.append(tool_response_message)
+        return memory
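Reviewer sketch, not part of the diff: despite the surviving -> str annotation, the method now returns a list of chat messages — the message-passing format of the commit title. For hypothetical log contents:

logs = [
    {"system_prompt": "You are an agent...", "task": "What is 2 + 2?"},
    {"llm_output": "I should add the numbers.\nAction: ...", "observation": "4"},
]
# get_inner_memory_from_logs would then produce:
# [
#     {"role": MessageRole.SYSTEM, "content": "You are an agent..."},
#     {"role": MessageRole.USER, "content": "Task: What is 2 + 2?"},
#     {"role": MessageRole.ASSISTANT, "content": "Thought: I should add the numbers.\nAction: ...\n"},
#     {"role": MessageRole.TOOL_RESPONSE, "content": "Observation: 4"},
# ]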


-    def show_message_history(self):
-        self.log.info('\n'.join(self.messages))



     def extract_action(self, llm_output: str, split_token: str) -> str:
         """
         Parse action from the LLM output
@@ -392,13 +463,7 @@ def execute(self, tool_name: str, arguments: Dict[str, str]) -> None:
                 if value in self.state:
                     arguments[key] = self.state[value]
             observation = self.toolbox.tools[tool_name](**arguments)
-            observation_message = {
-                "role": MessageRole.FUNCTION_RESPONSE,
-                "content": "Observation: " + observation.strip()
-            }
-            self.log.info(observation_message)
-            self.memory.append(observation_message)
-            return observation_message
+            return observation

         except Exception as e:
             raise AgentExecutionError(
@@ -410,26 +475,6 @@ def run(self, **kwargs):
         """To be implemented in the child class"""
         pass

-    def add_message(self, message: Dict[str, str]):
-        """
-        Append provided message to the message history of the current agent run.
-        Subsequent messages with the same role will be concatenated to a single message.
-
-        Args:
-            message (`Dict[str, str]`): Chat message with corresponding role.
-        """
-        if not set(message.keys()) == {"role", "content"}:
-            raise ValueError("Message should contain only 'role' and 'content' keys!")
-
-        if role := message["role"] not in MessageRole.roles():
-            raise ValueError(f"Incorrect role {role}, only {MessageRole.roles()} are supported for now.")
-
-        if len(self.messages) > 0 and self.messages[-1]["role"] == message["role"]:
-            self.messages[-1]["content"] += "\n" + message["content"]
-        else:
-            self.messages.append(message)


class CodeAgent(Agent):
"""
@@ -498,20 +543,18 @@ def run(self, task, return_generated_code=False, **kwargs):
-        self.log.info("====Executing with this prompt====")
-        self.log.info(self.prompt)
-        llm_output = self.llm_engine(self.prompt, stop=["Task:"])
-        self.memory = [self.system_message]
-
         self.task = task
-        task_message = {
-            "role": MessageRole.USER,
-            "content": f"Task: {self.task}"
-        }
-        self.add_message(task_message)
+        task_message = f"Task: {task}"

-        self.log.info("====Executing with this prompt====")
-        self.show_message_history()
+        self.logs.append({"task": task_message, "system_prompt": self.system_prompt})
+
+        memory = self.get_inner_memory_from_logs()
+
+        self.log.info("====Executing with these messages====")
+        self.log.info(memory)

         # Run LLM
-        llm_output = self.llm_engine(self.messages, stop=["Task:"])
+        llm_output = self.llm_engine(memory, stop=["Task:"])

         if return_generated_code:
             return llm_output["content"]
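Reviewer note: if llm_engine is the LLMEngine added above, it returns a plain string, so the unchanged return llm_output["content"] just above looks like a leftover from the message-dict era and would raise a TypeError whenever return_generated_code=True.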
@@ -604,42 +647,26 @@ def run(self, task, **kwargs):
         self.system_prompt = format_prompt(self._toolbox, self.system_prompt_template, self.tool_description_template)

         self.state=kwargs.copy()
-        if '<<additional_args>>' in self.system_prompt:
-            self.system_prompt = self.system_prompt.replace('<<additional_args>>', str(self.state))
-
-        self.task = task
-        task_message = f"Task: {self.task}"
+        if '<<additional_args>>' in self.system_prompt and len(self.state) > 0:
+            self.system_prompt = self.system_prompt.replace(
+                '<<additional_args>>',
+                f"You have been provided with these initial arguments, that you should absolutely use if needed rather than hallucinating arguments: {str(self.state)}."
+            )
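Reviewer sketch, not part of the diff: the new substitution in isolation, with a hypothetical prompt and state:

state = {"image": "image.png"}
system_prompt = "You are an agent. <<additional_args>>"
if '<<additional_args>>' in system_prompt and len(state) > 0:
    system_prompt = system_prompt.replace(
        '<<additional_args>>',
        f"You have been provided with these initial arguments, that you should absolutely use if needed rather than hallucinating arguments: {str(state)}."
    )
print(system_prompt)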

self.log.info("=====New task=====")
self.log.debug("System prompt is as follows:")
self.log.debug(self.system_prompt)
self.logs.append({"task": task_message, "system_prompt": self.system_prompt})


self.memory = [self.system_message]

self.task = task
task_message = {
"role": MessageRole.USER,
"content": f"Task: {self.task}"
}
self.add_message(task_message)
self.logs.append({"system_prompt": self.system_prompt, "task": task})

         final_answer = None
         iteration = 0

         while not final_answer and iteration < self.max_iterations:
-            self.logs.append({})
             try:
                 final_answer = self.step()
             except AgentError as e:
                 self.log.error(e)
                 self.logs[-1]["error"] = e
-                error_message = {
-                    "role": MessageRole.USER,
-                    "content": str(e) + ". Now let's retry."
-                }
-                self.add_message(error_message)
             finally:
                 iteration += 1

@@ -656,17 +683,23 @@ def step(self):
         """
         Runs agent step with the current prompt (task + state).
        """
-        agent_memory = get_inner_memory_from_logs(self.logs[:-1])
+        agent_memory = self.get_inner_memory_from_logs()
+        self.logs[-1]["agent_memory"] = agent_memory.copy()

-        self.prompt = agent_memory
-        # self.prompt = agent_memory + "\nThought: " # prepend the answer to steer the llm
-        self.log.debug("=====New step=====")
-        self.show_message_history()
+        # Add new step in logs
+        self.logs.append({})
+
+        self.prompt = agent_memory + "\nThought: " # prepend the answer to steer the llm
+        self.log.info("=====New step=====")
+        self.log.info("=====Calling LLM with these messages:=====")
+        self.log.info(agent_memory)

         if self.llm_engine_grammar:
-            llm_output = self.llm_engine(self.messages, stop=["Observation:"], grammar=self.llm_engine_grammar)
+            llm_output = self.llm_engine(self.prompt, stop=["Observation:"], grammar=self.llm_engine_grammar)
         else:
-            llm_output = self.llm_engine(self.messages, stop=["Observation:"])
+            llm_output = self.llm_engine(self.prompt, stop=["Observation:"])
         self.log.debug("=====Output message of the LLM:=====")
         self.log.debug(llm_output)
         self.logs[-1]["llm_output"] = llm_output
@@ -675,7 +708,6 @@ def step(self):
         self.log.debug("=====Extracting action=====")
         rationale, action = self.extract_action(
-            llm_output=llm_output["content"],
+            llm_output=llm_output,
             split_token="Action:"
         )

@@ -685,8 +717,10 @@
             raise AgentParsingError(f"Could not parse the given action: {e}.")

         self.logs[-1]["rationale"] = rationale
-        self.logs[-1]["tool"] = tool_name
-        self.logs[-1]["arguments"] = arguments
+        self.logs[-1]["tool_call"] = {
+            "tool_name": tool_name,
+            "tool_arguments": arguments
+        }
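Reviewer sketch, not part of the diff: after this change a step's log entry groups the parsed call under one key instead of two; roughly, with hypothetical values:

step_log = {
    "llm_output": "Thought: I need to compute the sum.\nAction: {...}",
    "rationale": "Thought: I need to compute the sum.",
    "tool_call": {
        "tool_name": "calculator",
        "tool_arguments": {"expression": "2 + 2"},
    },
    "observation": "4",
}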

         # Execute
         if tool_name == "final_answer":
@@ -699,9 +733,10 @@
             return answer
         else:
             observation = self.execute(tool_name, arguments)
+
             observation_type = type(observation)
             if observation_type in [str, int, float, bool]:
-                observation_message = str(observation).strip()
+                updated_information = str(observation).strip()
             else: # if the execution result is an object, store it
                 if observation_type == Image.Image:
                     observation_name = "image.png"
@@ -710,7 +745,8 @@ def step(self):
                 # TODO: improve observation name choice

                 self.state[observation_name] = observation
-                observation_message = f"Stored '{observation_name}' in memory."
-                self.log.info(observation_message)
-                self.logs[-1]["observation"] = observation_message
+                updated_information = f"Stored '{observation_name}' in memory."
+
+            self.log.info(updated_information)
+            self.logs[-1]["observation"] = updated_information
         return None
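Reviewer sketch, not part of the diff: the object branch in isolation — a non-primitive observation (here a PIL image) is parked in self.state under a fixed name, and only a short reference string is logged for the model to read:

from PIL import Image

state = {}
observation = Image.new("RGB", (64, 64))
if type(observation) == Image.Image:  # mirrors the exact-type check in the diff
    state["image.png"] = observation
    updated_information = "Stored 'image.png' in memory."
print(updated_information)  # Stored 'image.png' in memory.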
