Skip to content

Commit

Permalink
record tool (#333)
Browse files Browse the repository at this point in the history
* record tool

* remove shared code, fix tests
  • Loading branch information
bboynton97 authored Aug 5, 2024
1 parent 49a748b commit b0bafd0
Show file tree
Hide file tree
Showing 3 changed files with 367 additions and 3 deletions.
2 changes: 1 addition & 1 deletion agentops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .client import Client
from .event import Event, ActionEvent, LLMEvent, ToolEvent, ErrorEvent
from .decorators import record_function, track_agent
from .decorators import record_function, track_agent, record_tool
from .helpers import check_agentops_update
from .log_config import logger
from .session import Session
Expand Down
140 changes: 138 additions & 2 deletions agentops/decorators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import inspect
import functools
import inspect
from typing import Optional, Union
from uuid import uuid4

from .event import ActionEvent, ErrorEvent
from .event import ActionEvent, ErrorEvent, ToolEvent
from .helpers import check_call_stack_for_agent_id, get_ISO_time
from .session import Session
from .client import Client
Expand Down Expand Up @@ -146,6 +146,142 @@ def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
return decorator


def record_tool(tool_name: str):
"""
Decorator to record a tool use event before and after a function call.
Usage:
- Tools: Records function parameters and return statements of the
function being decorated. Additionally, timing information about
the action is recorded
Args:
tool_name (str): The name of the event to record.
"""

def decorator(func):
if inspect.iscoroutinefunction(func):

@functools.wraps(func)
async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]
if session is None:
if Client().is_multi_session:
raise ValueError(
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_tool"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ToolEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
name=tool_name,
)

try:
returns = await func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

event.returns = returns

# NOTE: Will likely remove in future since this is tightly coupled. Adding it to see how useful we find it for now
# TODO: check if screenshot is the url string we expect it to be? And not e.g. "True"
if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

# Re-raise the exception
raise

return returns

return async_wrapper
else:

@functools.wraps(func)
def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]
if session is None:
if Client().is_multi_session:
raise ValueError(
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_tool"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ToolEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
name=tool_name,
)

try:
returns = func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

event.returns = returns

if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

# Re-raise the exception
raise

return returns

return sync_wrapper

return decorator


def track_agent(name: Union[str, None] = None):
def decorator(obj):
if name:
Expand Down
228 changes: 228 additions & 0 deletions tests/test_record_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
import pytest
import requests_mock
import time
import agentops
from agentops.decorators import record_tool
from datetime import datetime

from agentops.helpers import clear_singletons
import contextlib

jwts = ["some_jwt", "some_jwt2", "some_jwt3"]


@pytest.fixture(autouse=True)
def setup_teardown():
clear_singletons()
yield
agentops.end_all_sessions() # teardown part


@contextlib.contextmanager
@pytest.fixture(autouse=True)
def mock_req():
with requests_mock.Mocker() as m:
url = "https://api.agentops.ai"
m.post(url + "/v2/create_events", text="ok")

# Use iter to create an iterator that can return the jwt values
jwt_tokens = iter(jwts)

# Use an inner function to change the response for each request
def create_session_response(request, context):
context.status_code = 200
return {"status": "success", "jwt": next(jwt_tokens)}

m.post(url + "/v2/create_session", json=create_session_response)
m.post(url + "/v2/update_session", json={"status": "success", "token_cost": 5})
m.post(url + "/v2/developer_errors", text="ok")

yield m


class TestRecordTool:
def setup_method(self):
self.url = "https://api.agentops.ai"
self.api_key = "11111111-1111-4111-8111-111111111111"
self.tool_name = "test_tool_name"
agentops.init(self.api_key, max_wait_time=5, auto_start_session=False)

def test_record_tool_decorator(self, mock_req):
agentops.start_session()

@record_tool(tool_name=self.tool_name)
def add_two(x, y):
return x + y

# Act
add_two(3, 4)
time.sleep(0.1)

# Assert
assert len(mock_req.request_history) == 2
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
request_json = mock_req.last_request.json()
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 3, "y": 4}
assert request_json["events"][0]["returns"] == 7

agentops.end_session(end_state="Success")

def test_record_tool_decorator_multiple(self, mock_req):
agentops.start_session()

# Arrange
@record_tool(tool_name=self.tool_name)
def add_three(x, y, z=3):
return x + y + z

# Act
add_three(1, 2)
time.sleep(0.1)
add_three(1, 2)
time.sleep(0.1)

# Assert
assert len(mock_req.request_history) == 3
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
request_json = mock_req.last_request.json()
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 1, "y": 2, "z": 3}
assert request_json["events"][0]["returns"] == 6

agentops.end_session(end_state="Success")

@pytest.mark.asyncio
async def test_async_tool_call(self, mock_req):
agentops.start_session()

@record_tool(self.tool_name)
async def async_add(x, y):
time.sleep(0.1)
return x + y

# Act
result = await async_add(3, 4)
time.sleep(0.1)

# Assert
assert result == 7
# Assert
assert len(mock_req.request_history) == 2
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
request_json = mock_req.last_request.json()
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 3, "y": 4}
assert request_json["events"][0]["returns"] == 7

init = datetime.fromisoformat(request_json["events"][0]["init_timestamp"])
end = datetime.fromisoformat(request_json["events"][0]["end_timestamp"])

assert (end - init).total_seconds() >= 0.1

agentops.end_session(end_state="Success")

def test_multiple_sessions_sync(self, mock_req):
session_1 = agentops.start_session()
session_2 = agentops.start_session()
assert session_1 is not None
assert session_2 is not None

# Arrange
@record_tool(tool_name=self.tool_name)
def add_three(x, y, z=3):
return x + y + z

# Act
add_three(1, 2, session=session_1)
time.sleep(0.1)
add_three(1, 2, session=session_2)
time.sleep(0.1)

# Assert
assert len(mock_req.request_history) == 4

request_json = mock_req.last_request.json()
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt2"
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 1, "y": 2, "z": 3}
assert request_json["events"][0]["returns"] == 6

second_last_request_json = mock_req.request_history[-2].json()
assert (
mock_req.request_history[-2].headers["X-Agentops-Api-Key"] == self.api_key
)
assert (
mock_req.request_history[-2].headers["Authorization"] == f"Bearer some_jwt"
)
assert second_last_request_json["events"][0]["name"] == self.tool_name
assert second_last_request_json["events"][0]["params"] == {
"x": 1,
"y": 2,
"z": 3,
}
assert second_last_request_json["events"][0]["returns"] == 6

session_1.end_session(end_state="Success")
session_2.end_session(end_state="Success")

@pytest.mark.asyncio
async def test_multiple_sessions_async(self, mock_req):
session_1 = agentops.start_session()
session_2 = agentops.start_session()
assert session_1 is not None
assert session_2 is not None

# Arrange
@record_tool(tool_name=self.tool_name)
async def async_add(x, y):
time.sleep(0.1)
return x + y

# Act
await async_add(1, 2, session=session_1)
time.sleep(0.1)
await async_add(1, 2, session=session_2)
time.sleep(0.1)

# Assert
assert len(mock_req.request_history) == 4

request_json = mock_req.last_request.json()
assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key
assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt2"
assert request_json["events"][0]["name"] == self.tool_name
assert request_json["events"][0]["params"] == {"x": 1, "y": 2}
assert request_json["events"][0]["returns"] == 3

second_last_request_json = mock_req.request_history[-2].json()
assert (
mock_req.request_history[-2].headers["X-Agentops-Api-Key"] == self.api_key
)
assert (
mock_req.request_history[-2].headers["Authorization"] == f"Bearer some_jwt"
)
assert second_last_request_json["events"][0]["name"] == self.tool_name
assert second_last_request_json["events"][0]["params"] == {
"x": 1,
"y": 2,
}
assert second_last_request_json["events"][0]["returns"] == 3

session_1.end_session(end_state="Success")
session_2.end_session(end_state="Success")

def test_require_session_if_multiple(self):
session_1 = agentops.start_session()
session_2 = agentops.start_session()

# Arrange
@record_tool(tool_name=self.tool_name)
def add_two(x, y):
time.sleep(0.1)
return x + y

with pytest.raises(ValueError):
# Act
add_two(1, 2)

0 comments on commit b0bafd0

Please sign in to comment.