Skip to content

Commit

Permalink
refactor: move ToolRunner to separate file, move artifacts handlers t…
Browse files Browse the repository at this point in the history
…o separate file
  • Loading branch information
maciejmajek committed Sep 30, 2024
1 parent 3282d76 commit 364ad48
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 178 deletions.
2 changes: 1 addition & 1 deletion src/rai/rai/agents/conversational_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from langgraph.prebuilt.tool_node import tools_condition
from rclpy.impl.rcutils_logger import RcutilsLogger

from rai.agents.state_based import ToolRunner
from rai.agents.tool_runner import ToolRunner

loggers_type = Union[RcutilsLogger, logging.Logger]

Expand Down
181 changes: 5 additions & 176 deletions src/rai/rai/agents/state_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,24 @@
#


import json
import logging
import pickle
import time
from functools import partial
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
TypedDict,
Union,
cast,
)
from typing import Any, Callable, Dict, List, Literal, Optional, TypedDict, Union

from langchain.chat_models.base import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as create_tool
from langgraph.graph import END, START, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt.tool_node import str_output
from langgraph.utils.runnable import RunnableCallable
from pydantic import BaseModel, Field, ValidationError
from rclpy.impl.rcutils_logger import RcutilsLogger

from rai.messages import (
HumanMultimodalMessage,
MultimodalArtifact,
ToolMultimodalMessage,
)
from rai.agents.tool_runner import ToolRunner
from rai.messages import HumanMultimodalMessage

loggers_type = Union[RcutilsLogger, logging.Logger]

Expand Down Expand Up @@ -87,149 +59,6 @@ class Report(BaseModel):
)


def get_stored_artifacts(
    tool_call_id: str, db_path="artifact_database.pkl"
) -> List[Any]:
    """Return the artifacts recorded for *tool_call_id*, or [] when none exist.

    The database is a pickled dict mapping tool_call_id -> list of artifacts;
    a missing database file simply yields an empty result.
    """
    # TODO(boczekbartek): refactor
    database_file = Path(db_path)
    if not database_file.is_file():
        return []

    with database_file.open("rb") as handle:
        artifact_database = pickle.load(handle)

    return artifact_database.get(tool_call_id, [])


def store_artifacts(
    tool_call_id: str, artifacts: List[Any], db_path: str = "artifact_database.pkl"
):
    """Append *artifacts* for *tool_call_id* to the pickle database at *db_path*.

    Creates the database file if it does not exist yet. Artifacts for an
    already-known tool_call_id are extended rather than overwritten.
    """
    # TODO(boczekbartek): refactor
    db_path = Path(db_path)
    if not db_path.is_file():
        # Bug fix: the original opened the hard-coded "artifact_database.pkl"
        # here (and below), silently ignoring a caller-supplied db_path.
        with db_path.open("wb") as file:
            pickle.dump({}, file)
    with db_path.open("rb") as file:
        artifact_database = pickle.load(file)
    if tool_call_id not in artifact_database:
        artifact_database[tool_call_id] = artifacts
    else:
        artifact_database[tool_call_id].extend(artifacts)
    with db_path.open("wb") as file:
        pickle.dump(artifact_database, file)


class ToolRunner(RunnableCallable):
    """Graph node that executes the tool calls attached to the newest AIMessage.

    Tool failures (pydantic validation errors and generic exceptions) are
    converted into error ToolMessages instead of propagating, and artifacts
    returned by tools are persisted and surfaced as ToolMultimodalMessages.
    """

    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        logger: loggers_type,
    ) -> None:
        # trace=False: this node only dispatches tools; tracing is not needed here.
        super().__init__(self._func, name=name, tags=tags, trace=False)
        self.tools_by_name: Dict[str, BaseTool] = {}
        for tool_ in tools:
            if not isinstance(tool_, BaseTool):
                # Wrap plain callables so every entry is invokable as a BaseTool.
                tool_ = create_tool(tool_)
            self.tools_by_name[tool_.name] = tool_
        self.logger = logger

    def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any:
        """Run all tool calls on the last message and append results to state.

        Raises ValueError when the state has no messages or the last message
        is not an AIMessage (i.e. carries no tool calls to execute).
        """
        config["max_concurrency"] = (
            1  # TODO(maciejmajek): use better mechanism for task queueing
        )
        if messages := input.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")

        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")

        def run_one(call: ToolCall):
            # Execute a single tool call; failures become error ToolMessages
            # so the model can observe and react to them.
            self.logger.info(f"Running tool: {call['name']}")
            artifact = None

            try:
                output = self.tools_by_name[call["name"]].invoke(call, config)  # type: ignore
                self.logger.info(
                    "Tool output (max 100 chars): " + str(output.content[0:100])
                )
            except ValidationError as e:
                errors = e.errors()
                for error in errors:
                    error.pop(
                        "url"
                    )  # get rid of the https://errors.pydantic.dev/... url

                error_message = f"""
                Validation error in tool {call["name"]}:
                {e.title}
                Number of errors: {e.error_count()}
                Errors:
                {json.dumps(errors, indent=2)}
                """
                self.logger.info(error_message)
                output = ToolMessage(
                    content=error_message,
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )
            except Exception as e:
                self.logger.info(f'Error in "{call["name"]}", error: {e}')
                output = ToolMessage(
                    content=f"Failed to run tool. Error: {e}",
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )

            if output.artifact is not None:
                artifact = output.artifact
                if not isinstance(artifact, dict):
                    raise ValueError(
                        "Artifact must be a dictionary with optional keys: 'images', 'audios'"
                    )

                artifact = cast(MultimodalArtifact, artifact)
                # Persist so artifacts can be looked up later by tool_call_id.
                store_artifacts(output.tool_call_id, [artifact])

            if artifact is not None:  # multimodal case
                return ToolMultimodalMessage(
                    content=str_output(output.content),
                    name=call["name"],
                    tool_call_id=call["id"],
                    images=artifact.get("images", []),
                    audios=artifact.get("audios", []),
                )

            return output

        with get_executor_for_config(config) as executor:
            raw_outputs = [*executor.map(run_one, message.tool_calls)]
            outputs: List[Any] = []
            for raw_output in raw_outputs:
                if isinstance(raw_output, ToolMultimodalMessage):
                    outputs.extend(
                        raw_output.postprocess()
                    )  # openai please allow tool messages with images!
                else:
                    outputs.append(raw_output)

            # because we can't answer an aiMessage with an alternating sequence of tool and human messages
            # we sort the messages by type so that the tool messages are sent first
            # for more information see implementation of ToolMultimodalMessage.postprocess
            outputs.sort(key=lambda x: x.__class__.__name__, reverse=True)
            input["messages"].extend(outputs)
            return input


def tools_condition(
state: Union[list[AnyMessage], dict[str, Any]],
) -> Literal["tools", "reporter"]:
Expand Down
139 changes: 139 additions & 0 deletions src/rai/rai/agents/tool_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast

from langchain_core.messages import AIMessage, ToolCall, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.tools import BaseTool
from langchain_core.tools import tool as create_tool
from langgraph.prebuilt.tool_node import str_output
from langgraph.utils.runnable import RunnableCallable
from pydantic import ValidationError
from rclpy.impl.rcutils_logger import RcutilsLogger

from rai.messages import MultimodalArtifact, ToolMultimodalMessage
from rai.utils.artifacts import store_artifacts


class ToolRunner(RunnableCallable):
    """Graph node that executes the tool calls attached to the newest AIMessage.

    Tool failures (pydantic validation errors and generic exceptions) are
    converted into error ToolMessages instead of propagating, and artifacts
    returned by tools are persisted and surfaced as ToolMultimodalMessages.
    """

    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        logger: Union[RcutilsLogger, logging.Logger],
    ) -> None:
        """Build the runner.

        Args:
            tools: Tools to expose; plain callables are wrapped via create_tool.
            name: Node name used by the graph.
            tags: Optional tracing tags.
            logger: ROS or stdlib logger used for progress/error reporting.
        """
        # trace=False: this node only dispatches tools; tracing is not needed here.
        super().__init__(self._func, name=name, tags=tags, trace=False)
        self.tools_by_name: Dict[str, BaseTool] = {}
        for tool_ in tools:
            if not isinstance(tool_, BaseTool):
                # Wrap plain callables so every entry is invokable as a BaseTool.
                tool_ = create_tool(tool_)
            self.tools_by_name[tool_.name] = tool_
        self.logger = logger

    def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any:
        """Run all tool calls on the last message and append results to state.

        Raises ValueError when the state has no messages or the last message
        is not an AIMessage (i.e. carries no tool calls to execute).
        """
        config["max_concurrency"] = (
            1  # TODO(maciejmajek): use better mechanism for task queueing
        )
        if messages := input.get("messages", []):
            message = messages[-1]
        else:
            raise ValueError("No message found in input")

        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")

        def run_one(call: ToolCall):
            # Execute a single tool call; failures become error ToolMessages
            # so the model can observe and react to them.
            self.logger.info(f"Running tool: {call['name']}")
            artifact = None

            try:
                output = self.tools_by_name[call["name"]].invoke(call, config)  # type: ignore
                self.logger.info(
                    "Tool output (max 100 chars): " + str(output.content[0:100])
                )
            except ValidationError as e:
                errors = e.errors()
                for error in errors:
                    # Bug fix: pop with a default — error dicts produced with
                    # errors(include_url=False) or by some custom error types
                    # have no "url" key, and a bare pop would raise KeyError
                    # here, crashing the error reporting itself.
                    error.pop(
                        "url", None
                    )  # get rid of the https://errors.pydantic.dev/... url

                error_message = f"""
                Validation error in tool {call["name"]}:
                {e.title}
                Number of errors: {e.error_count()}
                Errors:
                {json.dumps(errors, indent=2)}
                """
                self.logger.info(error_message)
                output = ToolMessage(
                    content=error_message,
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )
            except Exception as e:
                self.logger.info(f'Error in "{call["name"]}", error: {e}')
                output = ToolMessage(
                    content=f"Failed to run tool. Error: {e}",
                    name=call["name"],
                    tool_call_id=call["id"],
                    status="error",
                )

            if output.artifact is not None:
                artifact = output.artifact
                if not isinstance(artifact, dict):
                    raise ValueError(
                        "Artifact must be a dictionary with optional keys: 'images', 'audios'"
                    )

                artifact = cast(MultimodalArtifact, artifact)
                # Persist so artifacts can be looked up later by tool_call_id.
                store_artifacts(output.tool_call_id, [artifact])

            if artifact is not None:  # multimodal case
                return ToolMultimodalMessage(
                    content=str_output(output.content),
                    name=call["name"],
                    tool_call_id=call["id"],
                    images=artifact.get("images", []),
                    audios=artifact.get("audios", []),
                )

            return output

        with get_executor_for_config(config) as executor:
            raw_outputs = [*executor.map(run_one, message.tool_calls)]
            outputs: List[Any] = []
            for raw_output in raw_outputs:
                if isinstance(raw_output, ToolMultimodalMessage):
                    outputs.extend(
                        raw_output.postprocess()
                    )  # openai please allow tool messages with images!
                else:
                    outputs.append(raw_output)

            # because we can't answer an aiMessage with an alternating sequence of tool and human messages
            # we sort the messages by type so that the tool messages are sent first
            # for more information see implementation of ToolMultimodalMessage.postprocess
            outputs.sort(key=lambda x: x.__class__.__name__, reverse=True)
            input["messages"].extend(outputs)
            return input
38 changes: 38 additions & 0 deletions src/rai/rai/utils/artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pickle
from pathlib import Path
from typing import Any, List


def store_artifacts(
    tool_call_id: str, artifacts: List[Any], db_path: str = "artifact_database.pkl"
):
    """Append *artifacts* for *tool_call_id* to the pickle database at *db_path*.

    Creates the database file if it does not exist yet. Artifacts for an
    already-known tool_call_id are extended rather than overwritten.
    """
    # TODO(boczekbartek): refactor
    db_path = Path(db_path)
    if not db_path.is_file():
        # Bug fix: the original opened the hard-coded "artifact_database.pkl"
        # here (and below), silently ignoring a caller-supplied db_path.
        with db_path.open("wb") as file:
            pickle.dump({}, file)
    with db_path.open("rb") as file:
        artifact_database = pickle.load(file)
    if tool_call_id not in artifact_database:
        artifact_database[tool_call_id] = artifacts
    else:
        artifact_database[tool_call_id].extend(artifacts)
    with db_path.open("wb") as file:
        pickle.dump(artifact_database, file)


def get_stored_artifacts(
    tool_call_id: str, db_path="artifact_database.pkl"
) -> List[Any]:
    """Return the artifacts recorded for *tool_call_id*, or [] when none exist.

    The database is a pickled dict mapping tool_call_id -> list of artifacts;
    a missing database file simply yields an empty result.
    """
    # TODO(boczekbartek): refactor
    database_file = Path(db_path)
    if not database_file.is_file():
        return []

    with database_file.open("rb") as handle:
        artifact_database = pickle.load(handle)

    return artifact_database.get(tool_call_id, [])
Loading

0 comments on commit 364ad48

Please sign in to comment.