From 8821ce2fde98a7ce2d966398f47951a55cca50a7 Mon Sep 17 00:00:00 2001 From: Braelyn Boynton Date: Mon, 5 Aug 2024 13:40:50 -0700 Subject: [PATCH] remove shared code, fix tests --- agentops/decorators.py | 320 ++++++++++++++++++++++++++++++-------- tests/test_record_tool.py | 24 +-- 2 files changed, 269 insertions(+), 75 deletions(-) diff --git a/agentops/decorators.py b/agentops/decorators.py index e6b071c8..68c54a26 100644 --- a/agentops/decorators.py +++ b/agentops/decorators.py @@ -1,3 +1,4 @@ +import functools import inspect from typing import Optional, Union from uuid import uuid4 @@ -9,81 +10,274 @@ from .log_config import logger -def __record_action(func, event_name: str, EventClass): - def sync_wrapper(*args, session: Optional[Session] = None, **kwargs): - return __process_func(func, args, kwargs, session, event_name, EventClass) - - return sync_wrapper - - -def __record_action_async(func, event_name: str, EventClass): - async def async_wrapper(*args, session: Optional[Session] = None, **kwargs): - return await __process_func(func, args, kwargs, session, event_name, EventClass) - - return async_wrapper - - -def __process_func(func, args, kwargs, session, event_name, EventClass): - init_time = get_ISO_time() - if "session" in kwargs.keys(): - del kwargs["session"] - if session is None and Client().is_multi_session: - raise ValueError( - f"If multiple sessions exists, `session` is a required parameter in the function" - ) - func_args = inspect.signature(func).parameters - arg_names = list(func_args.keys()) - arg_values = { # Get default values - name: func_args[name].default - for name in arg_names - if func_args[name].default is not inspect._empty - } - arg_values.update(dict(zip(arg_names, args))) # Update with positional arguments - arg_values.update(kwargs) - - event = EventClass( - params=arg_values, - init_timestamp=init_time, - agent_id=check_call_stack_for_agent_id(), - action_type=event_name, - tool_name=event_name, - ) - - try: - returns = func(*args, **kwargs) - if isinstance( - returns, tuple - ): # If the function returns multiple values, record them all in the same event - returns = list(returns) - event.returns = returns - if hasattr(returns, "screenshot"): - event.screenshot = returns.screenshot # type: ignore - event.end_timestamp = get_ISO_time() - session.record(event) if session else Client().record(event) - - except Exception as e: - Client().record(ErrorEvent(trigger_event=event, exception=e)) - raise - - return returns - - def record_function(event_name: str): + """ + Decorator to record an event before and after a function call. + Usage: + - Actions: Records function parameters and return statements of the + function being decorated. Additionally, timing information about + the action is recorded + Args: + event_name (str): The name of the event to record. + """ + def decorator(func): if inspect.iscoroutinefunction(func): - return __record_action_async(func, event_name, ActionEvent) + + @functools.wraps(func) + async def async_wrapper(*args, session: Optional[Session] = None, **kwargs): + init_time = get_ISO_time() + if "session" in kwargs.keys(): + del kwargs["session"] + if session is None: + if Client().is_multi_session: + raise ValueError( + "If multiple sessions exists, `session` is a required parameter in the function decorated by @record_function" + ) + func_args = inspect.signature(func).parameters + arg_names = list(func_args.keys()) + # Get default values + arg_values = { + name: func_args[name].default + for name in arg_names + if func_args[name].default is not inspect._empty + } + # Update with positional arguments + arg_values.update(dict(zip(arg_names, args))) + arg_values.update(kwargs) + + event = ActionEvent( + params=arg_values, + init_timestamp=init_time, + agent_id=check_call_stack_for_agent_id(), + action_type=event_name, + ) + + try: + returns = await func(*args, **kwargs) + + # If the function returns multiple values, record them all in the same event + if isinstance(returns, tuple): + returns = list(returns) + + event.returns = returns + + # NOTE: Will likely remove in future since this is tightly coupled. Adding it to see how useful we find it for now + # TODO: check if screenshot is the url string we expect it to be? And not e.g. "True" + if hasattr(returns, "screenshot"): + event.screenshot = returns.screenshot # type: ignore + + event.end_timestamp = get_ISO_time() + + if session: + session.record(event) + else: + Client().record(event) + + except Exception as e: + Client().record(ErrorEvent(trigger_event=event, exception=e)) + + # Re-raise the exception + raise + + return returns + + return async_wrapper else: - return __record_action(func, event_name, ActionEvent) + + @functools.wraps(func) + def sync_wrapper(*args, session: Optional[Session] = None, **kwargs): + init_time = get_ISO_time() + if "session" in kwargs.keys(): + del kwargs["session"] + if session is None: + if Client().is_multi_session: + raise ValueError( + "If multiple sessions exists, `session` is a required parameter in the function decorated by @record_function" + ) + func_args = inspect.signature(func).parameters + arg_names = list(func_args.keys()) + # Get default values + arg_values = { + name: func_args[name].default + for name in arg_names + if func_args[name].default is not inspect._empty + } + # Update with positional arguments + arg_values.update(dict(zip(arg_names, args))) + arg_values.update(kwargs) + + event = ActionEvent( + params=arg_values, + init_timestamp=init_time, + agent_id=check_call_stack_for_agent_id(), + action_type=event_name, + ) + + try: + returns = func(*args, **kwargs) + + # If the function returns multiple values, record them all in the same event + if isinstance(returns, tuple): + returns = list(returns) + + event.returns = returns + + if hasattr(returns, "screenshot"): + event.screenshot = returns.screenshot # type: ignore + + event.end_timestamp = get_ISO_time() + + if session: + session.record(event) + else: + Client().record(event) + + except Exception as e: + Client().record(ErrorEvent(trigger_event=event, exception=e)) + + # Re-raise the exception + raise + + return returns + + return sync_wrapper return decorator def record_tool(tool_name: str): + """ + Decorator to record a tool use event before and after a function call. + Usage: + - Tools: Records function parameters and return statements of the + function being decorated. Additionally, timing information about + the action is recorded + Args: + tool_name (str): The name of the event to record. + """ + def decorator(func): if inspect.iscoroutinefunction(func): - return __record_action_async(func, tool_name, ToolEvent) + + @functools.wraps(func) + async def async_wrapper(*args, session: Optional[Session] = None, **kwargs): + init_time = get_ISO_time() + if "session" in kwargs.keys(): + del kwargs["session"] + if session is None: + if Client().is_multi_session: + raise ValueError( + "If multiple sessions exists, `session` is a required parameter in the function decorated by @record_tool" + ) + func_args = inspect.signature(func).parameters + arg_names = list(func_args.keys()) + # Get default values + arg_values = { + name: func_args[name].default + for name in arg_names + if func_args[name].default is not inspect._empty + } + # Update with positional arguments + arg_values.update(dict(zip(arg_names, args))) + arg_values.update(kwargs) + + event = ToolEvent( + params=arg_values, + init_timestamp=init_time, + agent_id=check_call_stack_for_agent_id(), + name=tool_name, + ) + + try: + returns = await func(*args, **kwargs) + + # If the function returns multiple values, record them all in the same event + if isinstance(returns, tuple): + returns = list(returns) + + event.returns = returns + + # NOTE: Will likely remove in future since this is tightly coupled. Adding it to see how useful we find it for now + # TODO: check if screenshot is the url string we expect it to be? And not e.g. "True" + if hasattr(returns, "screenshot"): + event.screenshot = returns.screenshot # type: ignore + + event.end_timestamp = get_ISO_time() + + if session: + session.record(event) + else: + Client().record(event) + + except Exception as e: + Client().record(ErrorEvent(trigger_event=event, exception=e)) + + # Re-raise the exception + raise + + return returns + + return async_wrapper else: - return __record_action(func, tool_name, ToolEvent) + + @functools.wraps(func) + def sync_wrapper(*args, session: Optional[Session] = None, **kwargs): + init_time = get_ISO_time() + if "session" in kwargs.keys(): + del kwargs["session"] + if session is None: + if Client().is_multi_session: + raise ValueError( + "If multiple sessions exists, `session` is a required parameter in the function decorated by @record_tool" + ) + func_args = inspect.signature(func).parameters + arg_names = list(func_args.keys()) + # Get default values + arg_values = { + name: func_args[name].default + for name in arg_names + if func_args[name].default is not inspect._empty + } + # Update with positional arguments + arg_values.update(dict(zip(arg_names, args))) + arg_values.update(kwargs) + + event = ToolEvent( + params=arg_values, + init_timestamp=init_time, + agent_id=check_call_stack_for_agent_id(), + name=tool_name, + ) + + try: + returns = func(*args, **kwargs) + + # If the function returns multiple values, record them all in the same event + if isinstance(returns, tuple): + returns = list(returns) + + event.returns = returns + + if hasattr(returns, "screenshot"): + event.screenshot = returns.screenshot # type: ignore + + event.end_timestamp = get_ISO_time() + + if session: + session.record(event) + else: + Client().record(event) + + except Exception as e: + Client().record(ErrorEvent(trigger_event=event, exception=e)) + + # Re-raise the exception + raise + + return returns + + return sync_wrapper return decorator diff --git a/tests/test_record_tool.py b/tests/test_record_tool.py index a0cf1d7c..d6f8fe4e 100644 --- a/tests/test_record_tool.py +++ b/tests/test_record_tool.py @@ -40,14 +40,14 @@ def create_session_response(request, context): yield m -class TestRecordAction: +class TestRecordTool: def setup_method(self): self.url = "https://api.agentops.ai" self.api_key = "11111111-1111-4111-8111-111111111111" self.tool_name = "test_tool_name" agentops.init(self.api_key, max_wait_time=5, auto_start_session=False) - def test_record_function_decorator(self, mock_req): + def test_record_tool_decorator(self, mock_req): agentops.start_session() @record_tool(tool_name=self.tool_name) @@ -62,13 +62,13 @@ def add_two(x, y): assert len(mock_req.request_history) == 2 assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key request_json = mock_req.last_request.json() - assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["name"] == self.tool_name assert request_json["events"][0]["params"] == {"x": 3, "y": 4} assert request_json["events"][0]["returns"] == 7 agentops.end_session(end_state="Success") - def test_record_function_decorator_multiple(self, mock_req): + def test_record_tool_decorator_multiple(self, mock_req): agentops.start_session() # Arrange @@ -86,17 +86,17 @@ def add_three(x, y, z=3): assert len(mock_req.request_history) == 3 assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key request_json = mock_req.last_request.json() - assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["name"] == self.tool_name assert request_json["events"][0]["params"] == {"x": 1, "y": 2, "z": 3} assert request_json["events"][0]["returns"] == 6 agentops.end_session(end_state="Success") @pytest.mark.asyncio - async def test_async_function_call(self, mock_req): + async def test_async_tool_call(self, mock_req): agentops.start_session() - @record_function(self.tool_name) + @record_tool(self.tool_name) async def async_add(x, y): time.sleep(0.1) return x + y @@ -111,7 +111,7 @@ async def async_add(x, y): assert len(mock_req.request_history) == 2 assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key request_json = mock_req.last_request.json() - assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["name"] == self.tool_name assert request_json["events"][0]["params"] == {"x": 3, "y": 4} assert request_json["events"][0]["returns"] == 7 @@ -145,7 +145,7 @@ def add_three(x, y, z=3): request_json = mock_req.last_request.json() assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt2" - assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["name"] == self.tool_name assert request_json["events"][0]["params"] == {"x": 1, "y": 2, "z": 3} assert request_json["events"][0]["returns"] == 6 @@ -156,7 +156,7 @@ def add_three(x, y, z=3): assert ( mock_req.request_history[-2].headers["Authorization"] == f"Bearer some_jwt" ) - assert second_last_request_json["events"][0]["tool_name"] == self.tool_name + assert second_last_request_json["events"][0]["name"] == self.tool_name assert second_last_request_json["events"][0]["params"] == { "x": 1, "y": 2, @@ -192,7 +192,7 @@ async def async_add(x, y): request_json = mock_req.last_request.json() assert mock_req.last_request.headers["X-Agentops-Api-Key"] == self.api_key assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt2" - assert request_json["events"][0]["tool_name"] == self.tool_name + assert request_json["events"][0]["name"] == self.tool_name assert request_json["events"][0]["params"] == {"x": 1, "y": 2} assert request_json["events"][0]["returns"] == 3 @@ -203,7 +203,7 @@ async def async_add(x, y): assert ( mock_req.request_history[-2].headers["Authorization"] == f"Bearer some_jwt" ) - assert second_last_request_json["events"][0]["tool_name"] == self.tool_name + assert second_last_request_json["events"][0]["name"] == self.tool_name assert second_last_request_json["events"][0]["params"] == { "x": 1, "y": 2,