-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: move ToolRunner to separate file, move artifacts handlers t…
…o separate file
- Loading branch information
1 parent
3282d76
commit 364ad48
Showing
5 changed files
with
184 additions
and
178 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright (C) 2024 Robotec.AI | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
import json | ||
import logging | ||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast | ||
|
||
from langchain_core.messages import AIMessage, ToolCall, ToolMessage | ||
from langchain_core.runnables import RunnableConfig | ||
from langchain_core.runnables.config import get_executor_for_config | ||
from langchain_core.tools import BaseTool | ||
from langchain_core.tools import tool as create_tool | ||
from langgraph.prebuilt.tool_node import str_output | ||
from langgraph.utils.runnable import RunnableCallable | ||
from pydantic import ValidationError | ||
from rclpy.impl.rcutils_logger import RcutilsLogger | ||
|
||
from rai.messages import MultimodalArtifact, ToolMultimodalMessage | ||
from rai.utils.artifacts import store_artifacts | ||
|
||
|
||
class ToolRunner(RunnableCallable): | ||
def __init__( | ||
self, | ||
tools: Sequence[Union[BaseTool, Callable]], | ||
*, | ||
name: str = "tools", | ||
tags: Optional[list[str]] = None, | ||
logger: Union[RcutilsLogger, logging.Logger], | ||
) -> None: | ||
super().__init__(self._func, name=name, tags=tags, trace=False) | ||
self.tools_by_name: Dict[str, BaseTool] = {} | ||
for tool_ in tools: | ||
if not isinstance(tool_, BaseTool): | ||
tool_ = create_tool(tool_) | ||
self.tools_by_name[tool_.name] = tool_ | ||
self.logger = logger | ||
|
||
def _func(self, input: dict[str, Any], config: RunnableConfig) -> Any: | ||
config["max_concurrency"] = ( | ||
1 # TODO(maciejmajek): use better mechanism for task queueing | ||
) | ||
if messages := input.get("messages", []): | ||
message = messages[-1] | ||
else: | ||
raise ValueError("No message found in input") | ||
|
||
if not isinstance(message, AIMessage): | ||
raise ValueError("Last message is not an AIMessage") | ||
|
||
def run_one(call: ToolCall): | ||
self.logger.info(f"Running tool: {call['name']}") | ||
artifact = None | ||
|
||
try: | ||
output = self.tools_by_name[call["name"]].invoke(call, config) # type: ignore | ||
self.logger.info( | ||
"Tool output (max 100 chars): " + str(output.content[0:100]) | ||
) | ||
except ValidationError as e: | ||
errors = e.errors() | ||
for error in errors: | ||
error.pop( | ||
"url" | ||
) # get rid of the https://errors.pydantic.dev/... url | ||
|
||
error_message = f""" | ||
Validation error in tool {call["name"]}: | ||
{e.title} | ||
Number of errors: {e.error_count()} | ||
Errors: | ||
{json.dumps(errors, indent=2)} | ||
""" | ||
self.logger.info(error_message) | ||
output = ToolMessage( | ||
content=error_message, | ||
name=call["name"], | ||
tool_call_id=call["id"], | ||
status="error", | ||
) | ||
except Exception as e: | ||
self.logger.info(f'Error in "{call["name"]}", error: {e}') | ||
output = ToolMessage( | ||
content=f"Failed to run tool. Error: {e}", | ||
name=call["name"], | ||
tool_call_id=call["id"], | ||
status="error", | ||
) | ||
|
||
if output.artifact is not None: | ||
artifact = output.artifact | ||
if not isinstance(artifact, dict): | ||
raise ValueError( | ||
"Artifact must be a dictionary with optional keys: 'images', 'audios'" | ||
) | ||
|
||
artifact = cast(MultimodalArtifact, artifact) | ||
store_artifacts(output.tool_call_id, [artifact]) | ||
|
||
if artifact is not None: # multimodal case | ||
return ToolMultimodalMessage( | ||
content=str_output(output.content), | ||
name=call["name"], | ||
tool_call_id=call["id"], | ||
images=artifact.get("images", []), | ||
audios=artifact.get("audios", []), | ||
) | ||
|
||
return output | ||
|
||
with get_executor_for_config(config) as executor: | ||
raw_outputs = [*executor.map(run_one, message.tool_calls)] | ||
outputs: List[Any] = [] | ||
for raw_output in raw_outputs: | ||
if isinstance(raw_output, ToolMultimodalMessage): | ||
outputs.extend( | ||
raw_output.postprocess() | ||
) # openai please allow tool messages with images! | ||
else: | ||
outputs.append(raw_output) | ||
|
||
# because we can't answer an aiMessage with an alternating sequence of tool and human messages | ||
# we sort the messages by type so that the tool messages are sent first | ||
# for more information see implementation of ToolMultimodalMessage.postprocess | ||
outputs.sort(key=lambda x: x.__class__.__name__, reverse=True) | ||
input["messages"].extend(outputs) | ||
return input |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import pickle | ||
from pathlib import Path | ||
from typing import Any, List | ||
|
||
|
||
def store_artifacts( | ||
tool_call_id: str, artifacts: List[Any], db_path="artifact_database.pkl" | ||
): | ||
# TODO(boczekbartek): refactor | ||
db_path = Path(db_path) | ||
if not db_path.is_file(): | ||
artifact_database = {} | ||
with open("artifact_database.pkl", "wb") as file: | ||
pickle.dump(artifact_database, file) | ||
with open("artifact_database.pkl", "rb") as file: | ||
artifact_database = pickle.load(file) | ||
if tool_call_id not in artifact_database: | ||
artifact_database[tool_call_id] = artifacts | ||
else: | ||
artifact_database[tool_call_id].extend(artifacts) | ||
with open("artifact_database.pkl", "wb") as file: | ||
pickle.dump(artifact_database, file) | ||
|
||
|
||
def get_stored_artifacts( | ||
tool_call_id: str, db_path="artifact_database.pkl" | ||
) -> List[Any]: | ||
# TODO(boczekbartek): refactor | ||
db_path = Path(db_path) | ||
if not db_path.is_file(): | ||
return [] | ||
|
||
with db_path.open("rb") as db: | ||
artifact_database = pickle.load(db) | ||
if tool_call_id in artifact_database: | ||
return artifact_database[tool_call_id] | ||
|
||
return [] |
Oops, something went wrong.