Skip to content

Commit

Permalink
record tool
Browse files Browse the repository at this point in the history
  • Loading branch information
bboynton97 committed Aug 5, 2024
1 parent 49a748b commit 2fcc658
Show file tree
Hide file tree
Showing 3 changed files with 298 additions and 128 deletions.
2 changes: 1 addition & 1 deletion agentops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .client import Client
from .event import Event, ActionEvent, LLMEvent, ToolEvent, ErrorEvent
from .decorators import record_function, track_agent
from .decorators import record_function, track_agent, record_tool
from .helpers import check_agentops_update
from .log_config import logger
from .session import Session
Expand Down
196 changes: 69 additions & 127 deletions agentops/decorators.py
Original file line number Diff line number Diff line change
@@ -1,147 +1,89 @@
import inspect
import functools
from typing import Optional, Union
from uuid import uuid4

from .event import ActionEvent, ErrorEvent
from .event import ActionEvent, ErrorEvent, ToolEvent
from .helpers import check_call_stack_for_agent_id, get_ISO_time
from .session import Session
from .client import Client
from .log_config import logger


def record_function(event_name: str):
"""
Decorator to record an event before and after a function call.
Usage:
- Actions: Records function parameters and return statements of the
function being decorated. Additionally, timing information about
the action is recorded
Args:
event_name (str): The name of the event to record.
"""
def __record_action(func, event_name: str, EventClass):
def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
return __process_func(func, args, kwargs, session, event_name, EventClass)

return sync_wrapper


def __record_action_async(func, event_name: str, EventClass):
async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
return await __process_func(func, args, kwargs, session, event_name, EventClass)

return async_wrapper


def __process_func(func, args, kwargs, session, event_name, EventClass):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]
if session is None and Client().is_multi_session:
raise ValueError(
f"If multiple sessions exists, `session` is a required parameter in the function"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
arg_values = { # Get default values
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
arg_values.update(dict(zip(arg_names, args))) # Update with positional arguments
arg_values.update(kwargs)

event = EventClass(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
action_type=event_name,
tool_name=event_name,
)

try:
returns = func(*args, **kwargs)
if isinstance(
returns, tuple
): # If the function returns multiple values, record them all in the same event
returns = list(returns)
event.returns = returns
if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore
event.end_timestamp = get_ISO_time()
session.record(event) if session else Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))
raise

return returns


def record_function(event_name: str):
def decorator(func):
if inspect.iscoroutinefunction(func):

@functools.wraps(func)
async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]
if session is None:
if Client().is_multi_session:
raise ValueError(
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_function"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ActionEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
action_type=event_name,
)

try:
returns = await func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

event.returns = returns

# NOTE: Will likely remove in future since this is tightly coupled. Adding it to see how useful we find it for now
# TODO: check if screenshot is the url string we expect it to be? And not e.g. "True"
if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

# Re-raise the exception
raise

return returns

return async_wrapper
return __record_action_async(func, event_name, ActionEvent)
else:
return __record_action(func, event_name, ActionEvent)

@functools.wraps(func)
def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]
if session is None:
if Client().is_multi_session:
raise ValueError(
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_function"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ActionEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
action_type=event_name,
)

try:
returns = func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

event.returns = returns

if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

# Re-raise the exception
raise
return decorator

return returns

return sync_wrapper
def record_tool(tool_name: str):
def decorator(func):
if inspect.iscoroutinefunction(func):
return __record_action_async(func, tool_name, ToolEvent)
else:
return __record_action(func, tool_name, ToolEvent)

return decorator

Expand Down
Loading

0 comments on commit 2fcc658

Please sign in to comment.