Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add POST route for testing tool execution via tool_id #2139

Merged
merged 7 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions letta/schemas/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,15 @@ class ToolUpdate(LettaBase):
class Config:
extra = "ignore" # Allows extra fields without validation errors
# TODO: Remove this, and clean usage of ToolUpdate everywhere else


class ToolRun(LettaBase):
    # Request payload for running an existing tool: both fields are required.
    id: str = Field(
        default=...,
        description="The ID of the tool to run.",
    )
    # Arguments arrive as a JSON string, not as a parsed object.
    args: str = Field(
        default=...,
        description="The arguments to pass to the tool (as stringified JSON).",
    )


class ToolRunFromSource(LettaBase):
    # Request payload for building a tool from raw source code and running it.
    #
    # `args` and `source_code` are required; `name` and `source_type` are optional.
    # The original declaration had the defaults of `name` and `source_code` swapped
    # (an Optional field marked required, and a `str` field defaulting to None).
    args: str = Field(..., description="The arguments to pass to the tool (as stringified JSON).")
    name: Optional[str] = Field(None, description="The name of the tool to run.")
    source_code: str = Field(..., description="The source code of the function.")
    source_type: Optional[str] = Field(None, description="The type of the source code.")
37 changes: 36 additions & 1 deletion letta/server/rest_api/routers/v1/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from letta.errors import LettaToolCreateError
from letta.orm.errors import UniqueConstraintViolationError
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.schemas.letta_message import FunctionReturn
from letta.schemas.tool import Tool, ToolCreate, ToolRun, ToolRunFromSource, ToolUpdate
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer

Expand Down Expand Up @@ -156,3 +157,37 @@ def add_base_tools(
"""
actor = server.get_user_or_default(user_id=user_id)
return server.tool_manager.add_base_tools(actor=actor)


@router.post("/{tool_id}/run", response_model=FunctionReturn, operation_id="run_tool")
def run_tool(
    server: SyncServer = Depends(get_letta_server),
    request: ToolRun = Body(...),
    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Run an existing tool on provided arguments
    """
    # Resolve the acting user from the optional `user_id` header.
    actor = server.get_user_or_default(user_id=user_id)

    # The ToolRun schema names its fields `id` and `args`; reading `request.tool_id` /
    # `request.tool_args` would raise AttributeError because those attributes do not exist.
    # NOTE(review): the `{tool_id}` path segment is not bound to a parameter here, so the
    # tool is selected by the body's `id` -- confirm this is the intended contract.
    return server.run_tool(tool_id=request.id, tool_args=request.args, user_id=actor.id)


@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source")
def run_tool_from_source(
    server: SyncServer = Depends(get_letta_server),
    request: ToolRunFromSource = Body(...),
    user_id: Optional[str] = Header(None, alias="user_id"),  # Extract user_id from header, default to None if not present
):
    """
    Attempt to build a tool from source, then run it on the provided arguments
    """
    # Resolve the acting user from the optional `user_id` header.
    acting_user = server.get_user_or_default(user_id=user_id)

    # Map the request payload onto the keyword arguments of the server call.
    call_kwargs = dict(
        tool_source=request.source_code,
        tool_source_type=request.source_type,
        tool_args=request.args,
        tool_name=request.name,
        user_id=acting_user.id,
    )
    return server.run_tool_from_source(**call_kwargs)
110 changes: 108 additions & 2 deletions letta/server/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# inspecting tools
import json
import os
import traceback
import warnings
Expand Down Expand Up @@ -54,7 +55,7 @@
# openai schemas
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message import FunctionReturn, LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import (
ArchivalMemorySummary,
Expand All @@ -76,9 +77,10 @@
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.utils import create_random_username, json_dumps, json_loads
from letta.utils import create_random_username, get_utc_time, json_dumps, json_loads

logger = get_logger(__name__)

Expand Down Expand Up @@ -1750,3 +1752,107 @@ def get_agent_block_by_label(self, user_id: str, agent_id: str, label: str) -> B
if block.label == label:
return block
return None

def run_tool(self, tool_id: str, tool_args: str, user_id: str) -> FunctionReturn:
"""Run a stored tool in the sandbox and return the result.

Args:
tool_id: ID of the tool to look up through the tool manager.
tool_args: Arguments for the tool, as a stringified JSON.
user_id: ID of the user used for the tool lookup and the sandbox run.

Returns:
A FunctionReturn with status "success" and the stringified sandbox return value,
or status "error" and an error message if anything raised during execution.

Raises:
ValueError: If tool_args is not valid JSON, or the tool has no name.
"""

# Decode the JSON argument string up front; a decode failure is reported as ValueError.
try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
raise ValueError("Invalid JSON string for tool_args")

# Get the tool by ID
user = self.user_manager.get_user_by_id(user_id=user_id)
tool = self.tool_manager.get_tool_by_id(tool_id=tool_id, actor=user)
# The sandbox is constructed from the tool's name, so a nameless tool cannot be run.
if tool.name is None:
raise ValueError(f"Tool with id {tool_id} does not have a name")

# TODO eventually allow using agent state in tools
agent_state = None

# Any exception raised from here on is turned into an "error" FunctionReturn below.
try:
sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id).run(agent_state=agent_state)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a bit of a bit, but can't this be the same code as run_tool_from_source once the Tool object is ready?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah this code I kind of stopped editing in the PR since I wasn't sure if we wanted to delete it?

e.g. if we're not going to use run_tool and only use run_tool_from_source, shouldn't we just add the later and kill the former (and add it if we need it later?)

# This ValueError is raised inside the try, so it is caught by the handler below.
if sandbox_run_result is None:
raise ValueError(f"Tool with id {tool_id} returned execution with None")
function_response = str(sandbox_run_result.func_return)

# The id fields are filled with the placeholder string "null".
return FunctionReturn(
id="null",
function_call_id="null",
date=get_utc_time(),
status="success",
function_return=function_response,
)
except Exception as e:
# same as agent.py
from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

# Cap the message length at MAX_ERROR_MESSAGE_CHAR_LIMIT characters.
error_msg = f"Error executing tool {tool.name}: {e}"
if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]

return FunctionReturn(
id="null",
cpacker marked this conversation as resolved.
Show resolved Hide resolved
function_call_id="null",
date=get_utc_time(),
status="error",
function_return=error_msg,
)

def run_tool_from_source(
    self,
    user_id: str,
    tool_args: str,
    tool_source: str,
    tool_source_type: Optional[str] = None,
    tool_name: Optional[str] = None,
) -> FunctionReturn:
    """Build an unsaved tool from source code and run it once in the sandbox.

    Args:
        user_id: ID of the user passed to the sandbox.
        tool_args: Arguments for the tool, as a stringified JSON.
        tool_source: Source code of the tool.
        tool_source_type: Type of the source code; only "python" (or None) is accepted.
        tool_name: Name given to the Tool object. May be omitted; presumably the Tool
            schema then derives a name from the source -- confirm against the Tool model.

    Returns:
        A FunctionReturn with status "success" and the stringified sandbox return value,
        or status "error" and a length-capped error message if anything raised while
        running the tool.

    Raises:
        ValueError: If tool_args is not valid JSON, the source type is unsupported,
            or the constructed tool ends up without a name.
    """
    try:
        tool_args_dict = json.loads(tool_args)
    except json.JSONDecodeError as err:
        # Chain the decode error so the position of the JSON problem is not lost.
        raise ValueError("Invalid JSON string for tool_args") from err

    if tool_source_type is not None and tool_source_type != "python":
        raise ValueError("Only Python source code is supported at this time")

    # NOTE: we're creating a floating Tool object and NOT persisting to DB
    tool = Tool(
        name=tool_name,
        source_code=tool_source,
    )
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if tool.name is None:
        raise ValueError("Failed to create tool object")

    # TODO eventually allow using agent state in tools
    agent_state = None

    # Next, attempt to run the tool with the sandbox. Any exception raised from here on
    # (including the None-result check) is reported as an "error" FunctionReturn.
    try:
        sandbox_run_result = ToolExecutionSandbox(tool.name, tool_args_dict, user_id, tool_object=tool).run(agent_state=agent_state)
        if sandbox_run_result is None:
            raise ValueError(f"Tool with id {tool.id} returned execution with None")
        function_response = str(sandbox_run_result.func_return)

        # The id fields are filled with the placeholder string "null".
        return FunctionReturn(
            id="null",
            function_call_id="null",
            date=get_utc_time(),
            status="success",
            function_return=function_response,
        )
    except Exception as e:
        # same as agent.py
        from letta.constants import MAX_ERROR_MESSAGE_CHAR_LIMIT

        # Cap the message length at MAX_ERROR_MESSAGE_CHAR_LIMIT characters.
        error_msg = f"Error executing tool {tool.name}: {e}"
        if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
            error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]

        return FunctionReturn(
            id="null",
            function_call_id="null",
            date=get_utc_time(),
            status="error",
            function_return=error_msg,
        )
23 changes: 14 additions & 9 deletions letta/services/tool_execution_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult, SandboxType
from letta.schemas.tool import Tool
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
Expand All @@ -27,7 +28,7 @@ class ToolExecutionSandbox:
# We make this a long random string to avoid collisions with any variables in the user's code
LOCAL_SANDBOX_RESULT_VAR_NAME = "result_ZQqiequkcFwRwwGQMqkt"

def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False):
def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=False, tool_object: Optional[Tool] = None):
self.tool_name = tool_name
self.args = args

Expand All @@ -36,14 +37,18 @@ def __init__(self, tool_name: str, args: dict, user_id: str, force_recreate=Fals
# agent_state is the state of the agent that invoked this run
self.user = UserManager().get_user_by_id(user_id=user_id)

# Get the tool
# TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent
# TODO: That would probably imply that agent_state is incorrectly configured
self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
if not self.tool:
raise ValueError(
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
)
# If a tool object is provided, we use it directly, otherwise pull via name
if tool_object is not None:
self.tool = tool_object
else:
# Get the tool via name
# TODO: So in theory, it's possible this retrieves a tool not provisioned to the agent
# TODO: That would probably imply that agent_state is incorrectly configured
self.tool = ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user)
if not self.tool:
raise ValueError(
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
)

self.sandbox_config_manager = SandboxConfigManager(tool_settings)
self.force_recreate = force_recreate
Expand Down
117 changes: 116 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.memory import ChatMemory
from letta.schemas.source import Source
from letta.server.server import SyncServer

Expand Down Expand Up @@ -540,3 +539,119 @@ def _test_get_messages_letta_format(
def test_get_messages_letta_format(server, user_id, agent_id):
    """Run the letta-format message check with reverse=False, then reverse=True."""
    for reverse_flag in (False, True):
        _test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse_flag)


# Source text of a single tool function, `ingest`, which returns
# f"Ingested message {message}". Passed as `tool_source` in test_tool_run.
EXAMPLE_TOOL_SOURCE = '''
def ingest(message: str):
"""
Ingest a message into the system.

Args:
message (str): The message to ingest into the system.

Returns:
str: The result of ingesting the message.
"""
return f"Ingested message {message}"

'''


# Source text containing two functions: a helper `util_do_nothing` (prints and
# returns nothing) followed by `ingest`, which calls the helper. Used in
# test_tool_run to check tool selection both by default and by explicit name.
EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR = '''
def util_do_nothing():
"""
A util function that does nothing.

Returns:
str: Dummy output.
"""
print("I'm a distractor")

def ingest(message: str):
"""
Ingest a message into the system.

Args:
message (str): The message to ingest into the system.

Returns:
str: The result of ingesting the message.
"""
util_do_nothing()
return f"Ingested message {message}"

'''


def test_tool_run(server, user_id, agent_id):
    """Test that the server can run tools"""

    def run_source(source, arguments, **overrides):
        # Every case runs Python source with JSON-encoded arguments; `overrides`
        # carries `tool_name` only for the cases that pick a function explicitly.
        outcome = server.run_tool_from_source(
            user_id=user_id,
            tool_source=source,
            tool_source_type="python",
            tool_args=json.dumps(arguments),
            **overrides,
        )
        print(outcome)
        return outcome

    # Valid arguments, no explicit tool name
    for text in ("Hello, world!", "Well well well"):
        outcome = run_source(EXAMPLE_TOOL_SOURCE, {"message": text})
        assert outcome.status == "success"
        assert outcome.function_return == f"Ingested message {text}", outcome.function_return

    # Wrong argument name: the run is reported as an error rather than raising
    outcome = run_source(EXAMPLE_TOOL_SOURCE, {"bad_arg": "oh no"})
    assert outcome.status == "error"
    assert "Error" in outcome.function_return, outcome.function_return
    assert "missing 1 required positional argument" in outcome.function_return, outcome.function_return

    # Test that we can still pull the tool out by default (pulls the last tool in the source)
    outcome = run_source(EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, {"message": "Well well well"})
    assert outcome.status == "success"
    assert outcome.function_return == "Ingested message Well well well", outcome.function_return

    # Test that we can pull the tool out by name
    outcome = run_source(
        EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR,
        {"message": "Well well well"},
        tool_name="ingest",
    )
    assert outcome.status == "success"
    assert outcome.function_return == "Ingested message Well well well", outcome.function_return

    # Test that we can pull a different tool out by name
    outcome = run_source(EXAMPLE_TOOL_SOURCE_WITH_DISTRACTOR, {}, tool_name="util_do_nothing")
    assert outcome.status == "success"
    assert outcome.function_return == str(None), outcome.function_return
Loading