Skip to content

Commit

Permalink
feat: add POST route for testing tool execution via tool_id (#2139)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Dec 3, 2024
1 parent 2f142a3 commit 9417b11
Show file tree
Hide file tree
Showing 5 changed files with 287 additions and 12 deletions.
12 changes: 12 additions & 0 deletions letta/schemas/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,15 @@ class ToolUpdate(LettaBase):
class Config:
extra = "ignore" # Allows extra fields without validation errors
# TODO: Remove this, and clean usage of ToolUpdate everywhere else


class ToolRun(LettaBase):
id: str = Field(..., description="The ID of the tool to run.")
args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")


class ToolRunFromSource(LettaBase):
args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
name: Optional[str] = Field(..., description="The name of the tool to run.")
source_code: str = Field(None, description="The source code of the function.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
38 changes: 37 additions & 1 deletion letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from letta.errors import LettaToolCreateError
from letta.orm.errors import UniqueConstraintViolationError
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.schemas.letta_message import FunctionReturn
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer

Expand Down Expand Up @@ -159,6 +160,41 @@ def add_base_tools(
return server.tool_manager.add_base_tools(actor=actor)


# NOTE: can re-enable if needed
# @router.post("/{tool_id}/run", response_model=FunctionReturn, operation_id="run_tool")
# def run_tool(
# server: SyncServer = Depends(get_letta_server),
# request: ToolRun = Body(...),
# user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
# ):
# """
# Run an existing tool on provided arguments
# """
# actor = server.get_user_or_default(user_id=user_id)

# return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id)


@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source")
def run_tool_from_source(
server: SyncServer = Depends(get_letta_server),
request: ToolRunFromSource = Body(...),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Attempt to build a tool from source, then run it on the provided arguments
"""
actor = server.get_user_or_default(user_id=user_id)

return server.run_tool_from_source(
tool_source=request.source_code,
tool_source_type=request.source_type,
tool_args=request.args,
tool_name=request.name,
user_id=actor.id,
)


# Specific routes for Composio


Expand Down
110 changes: 108 additions & 2 deletions letta/server/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# inspecting tools
import json
import os
import traceback
import warnings
Expand Down Expand Up @@ -56,7 +57,7 @@
# openai schemas
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message import FunctionReturn, LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import (
ArchivalMemorySummary,
Expand All @@ -78,9 +79,10 @@
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import create_random_username, json_dumps, json_loads
from letta.utils import create_random_username, get_utc_time, json_dumps, json_loads

logger = get_logger(__name__)

Expand Down Expand Up @@ -1764,6 +1766,110 @@ def get_agent_block_by_label(self, user_id: str, agent_id: str, label: str) -> B
return block
return None

# def run_tool(self, tool_id: str, tool_args: str, user_id: str) -> FunctionReturn:
# """Run a tool using the sandbox and return the result"""

# try:
# tool_args_dict = json.loads(tool_args)
# except json.JSONDecodeError:
# raise ValueError("Invalid JSON string for tool_args")

# # Get the tool by ID
# user = self.user_manager.get_user_by_id(user_id=user_id)
# tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=user)
# if tool.name is None:
# raise ValueError(f"Tool with id {tool_id} does not have a name")

# # TODO eventually allow using agent state in tools
# agent_state = None

# try:
# sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id).run(agent_state=agent_state)
# if sandbox_run_result is None:
# raise ValueError(f"Tool with id {tool_id} returned execution with None")
# function_response = str(sandbox_run_result.func_return)

# return FunctionReturn(
# id="null",
# function_call_id="null",
# date=get_utc_time(),
# status="success",
# function_return=function_response,
# )
# except Exception as e:
# # same as agent.py
# from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

# error_msg = f"Error executing tool {tool.name}: {e}"
# if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
# error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]

# return FunctionReturn(
# id="null",
# function_call_id="null",
# date=get_utc_time(),
# status="error",
# function_return=error_msg,
# )

def run_tool_from_source(
self,
user_id: str,
tool_args: str,
tool_source: str,
tool_source_type: Optional[str] = None,
tool_name: Optional[str] = None,
) -> FunctionReturn:
"""Run a tool from source code"""

try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
raise ValueError("Invalid JSON string for tool_args")

if tool_source_type is not None and tool_source_type != "python":
raise ValueError("Only Python source code is supported at this time")

# NOTE: we're creating a floating Tool object and NOT persiting to DB
tool = Tool(
name=tool_name,
source_code=tool_source,
)
assert tool.name is not None, "Failed to create tool object"

# TODO eventually allow using agent state in tools
agent_state = None

# Next, attempt to run the tool with the sandbox
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state)
if sandbox_run_result is None:
raise ValueError(f"Tool with id {tool.id} returned execution with None")
function_response = str(sandbox_run_result.func_return)

return FunctionReturn(
id="null",
function_call_id="null",
date=get_utc_time(),
status="success",
function_return=function_response,
)
except Exception as e:
# same as agent.py
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

error_msg = f"Error executing tool {tool.name}: {e}"
if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]

return FunctionReturn(
id="null",
function_call_id="null",
date=get_utc_time(),
status="error",
function_return=error_msg,
)

# Composio wrappers
def get_composio_apps(self) -> List["AppModel"]:
"""Get a list of all Composio apps with actions"""
Expand Down
23 changes: 14 additions & 9 deletions letta/services/tool_execution_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
Expand All @@ -27,7 +28,7 @@ class ToolExecutionSandbox:
# We make this a long random string to avoid collisions with any variables in the user's code
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"

def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False):
def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False, tool_object: Optional[Tool] = None):
self.tool_name = tool_name
self.args = args

Expand All @@ -36,14 +37,18 @@ def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=Fals
# agent_state is the state of the agent that invoked this run
self.user = UserManager().get_user_by_id(user_id=user_id)

# Get the tool
# TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent
# TODO: That would probably imply that agent_state is incorrectly configured
self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
if not self.tool:
raise ValueError(
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
)
# If a tool object is provided, we use it directly, otherwise pull via name
if tool_object is not None:
self.tool = tool_object
else:
# Get the tool via name
# TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent
# TODO: That would probably imply that agent_state is incorrectly configured
self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
if not self.tool:
raise ValueError(
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
)

self.sandbox_config_manager = SandboxConfigManager(tool_settings)
self.force_recreate = force_recreate
Expand Down
116 changes: 116 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,122 @@ def test_get_messages_letta_format(server, user_id, agent_id):
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)


EXAMPLE_TOOL_SOURCE = '''
def ingest(message: str):
"""
Ingest a message into the system.
Args:
message (str): The message to ingest into the system.
Returns:
str: The result of ingesting the message.
"""
return f"Ingested message {message}"
'''


EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR = '''
def util_do_nothing():
"""
A util function that does nothing.
Returns:
str: Dummy output.
"""
print("I'm a distractor")
def ingest(message: str):
"""
Ingest a message into the system.
Args:
message (str): The message to ingest into the system.
Returns:
str: The result of ingesting the message.
"""
util_do_nothing()
return f"Ingested message {message}"
'''


def test_tool_run(server, user_id, agent_id):
"""Test that the server can run tools"""

result = server.run_tool_from_source(
user_id=user_id,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"message": "Hello, world!"}),
# tool_name="ingest",
)
print(result)
assert result.status == "success"
assert result.function_return == "Ingested message Hello, world!", result.function_return

result = server.run_tool_from_source(
user_id=user_id,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
# tool_name="ingest",
)
print(result)
assert result.status == "success"
assert result.function_return == "Ingested message Well well well", result.function_return

result = server.run_tool_from_source(
user_id=user_id,
tool_source=EXAMPLE_TOOL_SOURCE,
tool_source_type="python",
tool_args=json.dumps({"bad_arg": "oh no"}),
# tool_name="ingest",
)
print(result)
assert result.status == "error"
assert "Error" in result.function_return, result.function_return
assert "missing 1 required positional argument" in result.function_return, result.function_return

# Test that we can still pull the tool out by default (pulls that last tool in the source)
result = server.run_tool_from_source(
user_id=user_id,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
# tool_name="ingest",
)
print(result)
assert result.status == "success"
assert result.function_return == "Ingested message Well well well", result.function_return

# Test that we can pull the tool out by name
result = server.run_tool_from_source(
user_id=user_id,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({"message": "Well well well"}),
tool_name="ingest",
)
print(result)
assert result.status == "success"
assert result.function_return == "Ingested message Well well well", result.function_return

# Test that we can pull a different tool out by name
result = server.run_tool_from_source(
user_id=user_id,
tool_source=EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
tool_source_type="python",
tool_args=json.dumps({}),
tool_name="util_do_nothing",
)
print(result)
assert result.status == "success"
assert result.function_return == str(None), result.function_return


def test_composio_client_simple(server):
apps = server.get_composio_apps()
# Assert there's some amount of apps returned
Expand Down

0 comments on commit 9417b11

Please sign in to comment.