Skip to content

Commit

Permalink
remove shared code, fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bboynton97 committed Aug 5, 2024
1 parent 2fcc658 commit 8821ce2
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 75 deletions.
320 changes: 257 additions & 63 deletions agentops/decorators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import inspect
from typing import Optional, Union
from uuid import uuid4
Expand All @@ -9,81 +10,274 @@
from .log_config import logger


def __record_action(func, event_name: str, EventClass):
def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
return __process_func(func, args, kwargs, session, event_name, EventClass)

return sync_wrapper


def __record_action_async(func, event_name: str, EventClass):
async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
return await __process_func(func, args, kwargs, session, event_name, EventClass)

return async_wrapper


def __process_func(func, args, kwargs, session, event_name, EventClass):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]
if session is None and Client().is_multi_session:
raise ValueError(
f"If multiple sessions exists, `session` is a required parameter in the function"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
arg_values = { # Get default values
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
arg_values.update(dict(zip(arg_names, args))) # Update with positional arguments
arg_values.update(kwargs)

event = EventClass(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
action_type=event_name,
tool_name=event_name,
)

try:
returns = func(*args, **kwargs)
if isinstance(
returns, tuple
): # If the function returns multiple values, record them all in the same event
returns = list(returns)
event.returns = returns
if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore
event.end_timestamp = get_ISO_time()
session.record(event) if session else Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))
raise

return returns


def record_function(event_name: str):
"""
Decorator to record an event before and after a function call.
Usage:
- Actions: Records function parameters and return statements of the
function being decorated. Additionally, timing information about
the action is recorded
Args:
event_name (str): The name of the event to record.
"""

def decorator(func):
if inspect.iscoroutinefunction(func):
return __record_action_async(func, event_name, ActionEvent)

@functools.wraps(func)
async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]
if session is None:
if Client().is_multi_session:
raise ValueError(
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_function"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ActionEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
action_type=event_name,
)

try:
returns = await func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

event.returns = returns

# NOTE: Will likely remove in future since this is tightly coupled. Adding it to see how useful we find it for now
# TODO: check if screenshot is the url string we expect it to be? And not e.g. "True"
if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

# Re-raise the exception
raise

return returns

return async_wrapper
else:
return __record_action(func, event_name, ActionEvent)

@functools.wraps(func)
def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]
if session is None:
if Client().is_multi_session:
raise ValueError(
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_function"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ActionEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
action_type=event_name,
)

try:
returns = func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

event.returns = returns

if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

# Re-raise the exception
raise

return returns

return sync_wrapper

return decorator


def record_tool(tool_name: str):
"""
Decorator to record a tool use event before and after a function call.
Usage:
- Tools: Records function parameters and return statements of the
function being decorated. Additionally, timing information about
the action is recorded
Args:
tool_name (str): The name of the event to record.
"""

def decorator(func):
if inspect.iscoroutinefunction(func):
return __record_action_async(func, tool_name, ToolEvent)

@functools.wraps(func)
async def async_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]

Check warning on line 167 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L167

Added line #L167 was not covered by tests
if session is None:
if Client().is_multi_session:
raise ValueError(

Check warning on line 170 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L170

Added line #L170 was not covered by tests
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_tool"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ToolEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
name=tool_name,
)

try:
returns = await func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

Check warning on line 197 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L197

Added line #L197 was not covered by tests

event.returns = returns

# NOTE: Will likely remove in future since this is tightly coupled. Adding it to see how useful we find it for now
# TODO: check if screenshot is the url string we expect it to be? And not e.g. "True"
if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

Check warning on line 204 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L204

Added line #L204 was not covered by tests

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

Check warning on line 214 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L213-L214

Added lines #L213 - L214 were not covered by tests

# Re-raise the exception
raise

Check warning on line 217 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L217

Added line #L217 was not covered by tests

return returns

return async_wrapper
else:
return __record_action(func, tool_name, ToolEvent)

@functools.wraps(func)
def sync_wrapper(*args, session: Optional[Session] = None, **kwargs):
init_time = get_ISO_time()
if "session" in kwargs.keys():
del kwargs["session"]

Check warning on line 228 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L228

Added line #L228 was not covered by tests
if session is None:
if Client().is_multi_session:
raise ValueError(
"If multiple sessions exists, `session` is a required parameter in the function decorated by @record_tool"
)
func_args = inspect.signature(func).parameters
arg_names = list(func_args.keys())
# Get default values
arg_values = {
name: func_args[name].default
for name in arg_names
if func_args[name].default is not inspect._empty
}
# Update with positional arguments
arg_values.update(dict(zip(arg_names, args)))
arg_values.update(kwargs)

event = ToolEvent(
params=arg_values,
init_timestamp=init_time,
agent_id=check_call_stack_for_agent_id(),
name=tool_name,
)

try:
returns = func(*args, **kwargs)

# If the function returns multiple values, record them all in the same event
if isinstance(returns, tuple):
returns = list(returns)

Check warning on line 258 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L258

Added line #L258 was not covered by tests

event.returns = returns

if hasattr(returns, "screenshot"):
event.screenshot = returns.screenshot # type: ignore

Check warning on line 263 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L263

Added line #L263 was not covered by tests

event.end_timestamp = get_ISO_time()

if session:
session.record(event)
else:
Client().record(event)

except Exception as e:
Client().record(ErrorEvent(trigger_event=event, exception=e))

Check warning on line 273 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L272-L273

Added lines #L272 - L273 were not covered by tests

# Re-raise the exception
raise

Check warning on line 276 in agentops/decorators.py

View check run for this annotation

Codecov / codecov/patch

agentops/decorators.py#L276

Added line #L276 was not covered by tests

return returns

return sync_wrapper

return decorator

Expand Down
Loading

0 comments on commit 8821ce2

Please sign in to comment.