Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: separate files for ToolRunner and artifact handlers #256

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/rai/rai/agents/conversational_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from langgraph.prebuilt.tool_node import tools_condition
from rclpy.impl.rcutils_logger import RcutilsLogger

from rai.agents.state_based import ToolRunner
from rai.agents.tool_runner import ToolRunner

loggers_type = Union[RcutilsLogger, logging.Logger]

Expand Down
181 changes: 5 additions & 176 deletions src/rai/rai/agents/state_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,24 @@
#


import json
import logging
import pickle
import time
from functools import partial
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
TypedDict,
Union,
cast,
)
from typing import Any, Callable, Dict, List, Literal, Optional, TypedDict, Union

from langchain.chat_models.base import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as create_tool
from langgraph.graph import END, START, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt.tool_node import str_output
from langgraph.utils.runnable import RunnableCallable
from pydantic import BaseModel, Field, ValidationError
from rclpy.impl.rcutils_logger import RcutilsLogger

from rai.messages import (
HumanMultimodalMessage,
MultimodalArtifact,
ToolMultimodalMessage,
)
from rai.agents.tool_runner import ToolRunner
from rai.messages import HumanMultimodalMessage

loggers_type = Union[RcutilsLogger, logging.Logger]

Expand Down Expand Up @@ -87,149 +59,6 @@ class Report(BaseModel):
)


def get_stored_artifacts(
    tool_call_id: str, db_path: str = "artifact_database.pkl"
) -> List[Any]:
    """Return the artifacts stored for a given tool call, or [] if none.

    Args:
        tool_call_id: Key under which the artifacts were stored.
        db_path: Path to the pickle database file.

    Returns:
        The list of artifacts recorded for ``tool_call_id``; an empty list
        when the database file does not exist or the id is unknown.
    """
    # TODO(boczekbartek): refactor
    # NOTE(review): pickle.load on an attacker-controlled file executes
    # arbitrary code — only use with trusted, locally-written databases.
    path = Path(db_path)
    if not path.is_file():
        return []

    with path.open("rb") as db:
        artifact_database = pickle.load(db)
    # Single lookup with a default instead of membership test + index.
    return artifact_database.get(tool_call_id, [])


def store_artifacts(
    tool_call_id: str, artifacts: List[Any], db_path: str = "artifact_database.pkl"
):
    """Append artifacts for a tool call to the pickle database.

    Fixes a bug where the literal filename ``"artifact_database.pkl"`` was
    used for all reads/writes regardless of ``db_path``, so a non-default
    path was silently ignored (and reading a not-yet-created custom path
    raised FileNotFoundError).

    Args:
        tool_call_id: Key under which to record the artifacts.
        artifacts: Artifacts to append for that id.
        db_path: Path to the pickle database file; created if missing.
    """
    # TODO(boczekbartek): refactor
    path = Path(db_path)
    if path.is_file():
        with path.open("rb") as file:
            artifact_database = pickle.load(file)
    else:
        artifact_database = {}
    # setdefault creates the entry on first use; extend on later calls.
    artifact_database.setdefault(tool_call_id, []).extend(artifacts)
    with path.open("wb") as file:
        pickle.dump(artifact_database, file)


class ToolRunner(RunnableCallable):
    """Graph node that executes the tool calls of the last AIMessage.

    Tool outputs that carry a multimodal artifact (images/audios) are
    persisted via ``store_artifacts`` and returned as
    ``ToolMultimodalMessage``; plain outputs pass through unchanged.
    """

    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        logger: loggers_type,
    ) -> None:
        """Build a name -> tool lookup from the given tools.

        Args:
            tools: Tools to expose; plain callables are wrapped with
                ``create_tool`` so everything is stored as a ``BaseTool``.
            name: Node name used by the runnable graph.
            tags: Optional tags forwarded to ``RunnableCallable``.
            logger: ROS or stdlib logger for tool progress and errors.
        """
        super().__init__(self._func, name=name, tags=tags, trace=False)
        self.tools_by_name: Dict[str, BaseTool] = {}
        for tool_ in tools:
            if not isinstance(tool_, BaseTool):
                tool_ = create_tool(tool_)
            self.tools_by_name[tool_.name] = tool_
        self.logger = logger

    def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any:
        """Run every tool call of the last message, append outputs to state.

        Raises:
            ValueError: If the state has no messages or the last message is
                not an AIMessage, or a tool artifact is not a dict.
        """
        config["max_concurrency"] = (
            1  # TODO(maciejmajek): use better mechanism for task queueing
        )
        if messages := input.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")

        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")

        def run_one(call: ToolCall):
            # Invoke a single tool call, converting failures into
            # status="error" ToolMessages instead of crashing the node.
            self.logger.info(f"Running tool: {call['name']}")
            artifact = None

            try:
                output = self.tools_by_name[call["name"]].invoke(call, config)  # type: ignore
                self.logger.info(
                    "Tool output (max 100 chars): " + str(output.content[0:100])
                )
            except ValidationError as e:
                errors = e.errors()
                for error in errors:
                    # Drop the https://errors.pydantic.dev/... url; use a
                    # default so a missing key cannot raise KeyError here.
                    error.pop("url", None)

                error_message = f"""
                Validation error in tool {call["name"]}:
                {e.title}
                Number of errors: {e.error_count()}
                Errors:
                {json.dumps(errors, indent=2)}
                """
                self.logger.info(error_message)
                output = ToolMessage(
                    content=error_message,
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )
            except Exception as e:
                self.logger.info(f'Error in "{call["name"]}", error: {e}')
                output = ToolMessage(
                    content=f"Failed to run tool. Error: {e}",
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )

            if output.artifact is not None:
                artifact = output.artifact
                if not isinstance(artifact, dict):
                    raise ValueError(
                        "Artifact must be a dictionary with optional keys: 'images', 'audios'"
                    )

                artifact = cast(MultimodalArtifact, artifact)
                store_artifacts(output.tool_call_id, [artifact])

            # Only build a multimodal message when the artifact actually
            # contains media; an empty artifact dict stays a plain output.
            if artifact is not None and (
                len(artifact.get("images", [])) > 0
                or len(artifact.get("audios", [])) > 0
            ):  # multimodal case, we currently support images and audios artifacts
                return ToolMultimodalMessage(
                    content=str_output(output.content),
                    name=call["name"],
                    tool_call_id=call["id"],
                    images=artifact.get("images", []),
                    audios=artifact.get("audios", []),
                )

            return output

        with get_executor_for_config(config) as executor:
            raw_outputs = [*executor.map(run_one, message.tool_calls)]
            outputs: List[Any] = []
            for raw_output in raw_outputs:
                if isinstance(raw_output, ToolMultimodalMessage):
                    outputs.extend(
                        raw_output.postprocess()
                    )  # openai please allow tool messages with images!
                else:
                    outputs.append(raw_output)

            # because we can't answer an aiMessage with an alternating sequence of tool and human messages
            # we sort the messages by type so that the tool messages are sent first
            # for more information see implementation of ToolMultimodalMessage.postprocess
            outputs.sort(key=lambda x: x.__class__.__name__, reverse=True)
            input["messages"].extend(outputs)
            return input


def tools_condition(
state: Union[list[AnyMessage], dict[str, Any]],
) -> Literal["tools", "reporter"]:
Expand Down
142 changes: 142 additions & 0 deletions src/rai/rai/agents/tool_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast

from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as create_tool
from langgraph.prebuilt.tool_node import str_output
from langgraph.utils.runnable import RunnableCallable
from pydantic import ValidationError
from rclpy.impl.rcutils_logger import RcutilsLogger

from rai.messages import MultimodalArtifact, ToolMultimodalMessage
from rai.utils.artifacts import store_artifacts


class ToolRunner(RunnableCallable):
    # Graph node that executes all tool calls attached to the last AIMessage
    # in the state and appends the resulting tool messages to the state.
    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        logger: Union[RcutilsLogger, logging.Logger],
    ) -> None:
        """Build a name -> tool lookup from the given tools.

        Args:
            tools: Tools to expose; plain callables are wrapped with
                ``create_tool`` so everything is stored as a ``BaseTool``.
            name: Node name used by the runnable graph.
            tags: Optional tags forwarded to ``RunnableCallable``.
            logger: ROS or stdlib logger for tool progress and errors.
        """
        super().__init__(self._func, name=name, tags=tags, trace=False)
        self.tools_by_name: Dict[str, BaseTool] = {}
        for tool_ in tools:
            if not isinstance(tool_, BaseTool):
                # Wrap plain callables so invocation below is uniform.
                tool_ = create_tool(tool_)
            self.tools_by_name[tool_.name] = tool_
        self.logger = logger

    def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any:
        """Run every tool call of the last message and extend input["messages"].

        Expects ``input["messages"]`` to end with an ``AIMessage`` carrying
        ``tool_calls``; returns the mutated ``input`` dict.

        Raises:
            ValueError: If no messages are present, the last message is not
                an AIMessage, or a tool artifact is not a dict.
        """
        config["max_concurrency"] = (
            1  # TODO(maciejmajek): use better mechanism for task queueing
        )
        if messages := input.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")

        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")

        def run_one(call: ToolCall):
            # Invoke a single tool call; failures are converted into
            # status="error" ToolMessages rather than crashing the node.
            self.logger.info(f"Running tool: {call['name']}")
            artifact = None

            try:
                output = self.tools_by_name[call["name"]].invoke(call, config)  # type: ignore
                self.logger.info(
                    "Tool output (max 100 chars): " + str(output.content[0:100])
                )
            except ValidationError as e:
                # Tool-argument validation failed; report the pydantic
                # errors back to the model as an error ToolMessage.
                errors = e.errors()
                for error in errors:
                    error.pop(
                        "url"
                    )  # get rid of the https://errors.pydantic.dev/... url

                error_message = f"""
                Validation error in tool {call["name"]}:
                {e.title}
                Number of errors: {e.error_count()}
                Errors:
                {json.dumps(errors, indent=2)}
                """
                self.logger.info(error_message)
                output = ToolMessage(
                    content=error_message,
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )
            except Exception as e:
                output = ToolMessage(
                    content=f"Failed to run tool. Error: {e}",
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )

            # NOTE(review): relies on the tool returning a message with an
            # ``artifact`` attribute (langchain content_and_artifact style).
            if output.artifact is not None:
                artifact = output.artifact
                if not isinstance(artifact, dict):
                    raise ValueError(
                        "Artifact must be a dictionary with optional keys: 'images', 'audios'"
                    )

                artifact = cast(MultimodalArtifact, artifact)
                # Persist artifacts keyed by tool_call_id for later retrieval.
                store_artifacts(output.tool_call_id, [artifact])

            # Only promote to a multimodal message when the artifact actually
            # contains media; an empty artifact dict stays a plain output.
            if artifact is not None and (
                len(artifact.get("images", [])) > 0
                or len(artifact.get("audios", [])) > 0
            ):  # multimodal case, we currently support images and audios artifacts
                return ToolMultimodalMessage(
                    content=str_output(output.content),
                    name=call["name"],
                    tool_call_id=call["id"],
                    images=artifact.get("images", []),
                    audios=artifact.get("audios", []),
                )

            return output

        # max_concurrency=1 above makes this effectively sequential.
        with get_executor_for_config(config) as executor:
            raw_outputs = [*executor.map(run_one, message.tool_calls)]
            outputs: List[Any] = []
            for raw_output in raw_outputs:
                if isinstance(raw_output, ToolMultimodalMessage):
                    outputs.extend(
                        raw_output.postprocess()
                    )  # openai please allow tool messages with images!
                else:
                    outputs.append(raw_output)

            # because we can't answer an aiMessage with an alternating sequence of tool and human messages
            # we sort the messages by type so that the tool messages are sent first
            # for more information see implementation of ToolMultimodalMessage.postprocess
            outputs.sort(key=lambda x: x.__class__.__name__, reverse=True)
            input["messages"].extend(outputs)
            return input
Loading