Skip to content

Commit

Permalink
Overhaul tools (autogenhub#53)
Browse files Browse the repository at this point in the history
* Overhaul tools

* add a simple test

* mypy fixes

* format
  • Loading branch information
jackgerrits authored Jun 5, 2024
1 parent 837c388 commit ab420cc
Show file tree
Hide file tree
Showing 36 changed files with 450 additions and 228 deletions.
2 changes: 1 addition & 1 deletion examples/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass

from agnext.application import SingleThreadedAgentRuntime
from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import Agent, AgentRuntime, CancellationToken


Expand Down
49 changes: 32 additions & 17 deletions examples/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import os
from typing import Annotated, Callable
from typing import Callable

import openai
from agnext.application import (
Expand All @@ -13,19 +13,44 @@
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
from agnext.chat.types import TextMessage
from agnext.components.function_executor._impl.in_process_function_executor import (
InProcessFunctionExecutor,
)
from agnext.components.models import OpenAI, SystemMessage
from agnext.core import Agent, AgentRuntime
from agnext.components.tools import BaseTool
from agnext.core import Agent, AgentRuntime, CancellationToken
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from tavily import TavilyClient
from pydantic import BaseModel, Field
from tavily import TavilyClient # type: ignore
from typing_extensions import Any, override

logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)


class SearchQuery(BaseModel):
query: str = Field(description="The search query.")


class SearchResult(BaseModel):
result: str = Field(description="The search results.")


class SearchTool(BaseTool[SearchQuery, SearchResult]):
def __init__(self) -> None:
super().__init__(
args_type=SearchQuery,
return_type=SearchResult,
name="search",
description="Search the web.",
)

async def run(self, args: SearchQuery, cancellation_token: CancellationToken) -> SearchResult:
client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY")) # type: ignore
result = await asyncio.create_task(client.search(args.query)) # type: ignore
if result:
return SearchResult(result=json.dumps(result, indent=2, ensure_ascii=False))

return SearchResult(result="No results found.")


class LoggingHandler(DefaultInterventionHandler): # type: ignore
send_color = "\033[31m"
response_color = "\033[34m"
Expand Down Expand Up @@ -76,16 +101,6 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
thread_id=tester_oai_thread.id,
)

def search(query: Annotated[str, "The search query."]) -> Annotated[str, "The search results."]:
"""Search the web."""
client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
result = client.search(query) # type: ignore
if result:
return json.dumps(result, indent=2, ensure_ascii=False) # type: ignore
return "No results found."

function_executor = InProcessFunctionExecutor(functions=[search])

product_manager = ChatCompletionAgent(
name="ProductManager",
description="A product manager that performs research and comes up with specs.",
Expand All @@ -95,7 +110,7 @@ def search(query: Annotated[str, "The search query."]) -> Annotated[str, "The se
SystemMessage("You can use the search tool to find information on the web."),
],
model_client=OpenAI(model="gpt-4-turbo"),
function_executor=function_executor,
tools=[SearchTool()],
)

planner = ChatCompletionAgent(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ disallow_untyped_decorators = true
disallow_any_unimported = true

[tool.pyright]
include = ["src", "tests"]
include = ["src", "tests", "examples"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false
Expand Down
58 changes: 33 additions & 25 deletions src/agnext/chat/agents/chat_completion_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import json
from typing import Any, Coroutine, Dict, List, Mapping, Tuple
from typing import Any, Coroutine, Dict, List, Mapping, Sequence, Tuple

from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import (
Expand All @@ -12,13 +12,13 @@
TextMessage,
)
from agnext.chat.utils import convert_messages_to_llm_messages
from agnext.components.function_executor import FunctionExecutor
from agnext.components.models import FunctionExecutionResult, FunctionExecutionResultMessage, ModelClient, SystemMessage
from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.components.types import (
from agnext.components import (
FunctionCall,
FunctionSignature,
TypeRoutedAgent,
message_handler,
)
from agnext.components.models import FunctionExecutionResult, FunctionExecutionResultMessage, ModelClient, SystemMessage
from agnext.components.tools import Tool
from agnext.core import AgentRuntime, CancellationToken


Expand All @@ -30,13 +30,13 @@ def __init__(
runtime: AgentRuntime,
system_messages: List[SystemMessage],
model_client: ModelClient,
function_executor: FunctionExecutor | None = None,
tools: Sequence[Tool] = [],
) -> None:
super().__init__(name, description, runtime)
self._system_messages = system_messages
self._client = model_client
self._chat_messages: List[Message] = []
self._function_executor = function_executor
self._tools = tools

@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
Expand All @@ -52,23 +52,18 @@ async def on_reset(self, message: Reset, cancellation_token: CancellationToken)
async def on_respond_now(
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage | FunctionCallMessage:
# Get function signatures.
function_signatures: List[FunctionSignature] = (
[] if self._function_executor is None else list(self._function_executor.function_signatures)
)

# Get a response from the model.
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
functions=function_signatures,
tools=self._tools,
json_output=message.response_format == ResponseFormat.json_object,
)

# If the agent has function executor, and the response is a list of
# tool calls, iterate with itself until we get a response that is not a
# list of tool calls.
while (
self._function_executor is not None
len(self._tools) > 0
and isinstance(response.content, list)
and all(isinstance(x, FunctionCall) for x in response.content)
):
Expand All @@ -81,7 +76,7 @@ async def on_respond_now(
# Make an assistant message from the response.
response = await self._client.create(
self._system_messages + convert_messages_to_llm_messages(self._chat_messages, self.name),
functions=function_signatures,
tools=self._tools,
json_output=message.response_format == ResponseFormat.json_object,
)

Expand All @@ -105,8 +100,8 @@ async def on_respond_now(
async def on_tool_call_message(
self, message: FunctionCallMessage, cancellation_token: CancellationToken
) -> FunctionExecutionResultMessage:
if self._function_executor is None:
raise ValueError("Function executor is not set.")
if len(self._tools) == 0:
raise ValueError("No tools available")

# Add a tool call message.
self._chat_messages.append(message)
Expand All @@ -127,7 +122,9 @@ async def on_tool_call_message(
)
continue
# Execute the function.
future = self.execute_function(function_call.name, arguments, function_call.id)
future = self.execute_function(
function_call.name, arguments, function_call.id, cancellation_token=cancellation_token
)
# Append the async result.
execution_futures.append(future)
if execution_futures:
Expand All @@ -146,14 +143,25 @@ async def on_tool_call_message(
# Return the results.
return tool_call_result_msg

async def execute_function(self, name: str, args: Dict[str, Any], call_id: str) -> Tuple[str, str]:
if self._function_executor is None:
raise ValueError("Function executor is not set.")
async def execute_function(
self, name: str, args: Dict[str, Any], call_id: str, cancellation_token: CancellationToken
) -> Tuple[str, str]:
# Find tool
tool = next((t for t in self._tools if t.name == name), None)
if tool is None:
raise ValueError(f"Tool {name} not found.")
try:
result = await self._function_executor.execute_function(name, args)
result = await tool.run_json(args, cancellation_token)
result_json_or_str = result.model_dump()
if isinstance(result, dict):
result_str = json.dumps(result_json_or_str)
elif isinstance(result_json_or_str, str):
result_str = result_json_or_str
else:
raise ValueError(f"Unexpected result type: {type(result)}")
except Exception as e:
result = f"Error: {str(e)}"
return (result, call_id)
result_str = f"Error: {str(e)}"
return (result_str, call_id)

def save_state(self) -> Mapping[str, Any]:
return {
Expand Down
2 changes: 1 addition & 1 deletion src/agnext/chat/agents/oai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Reset, RespondNow, ResponseFormat, TextMessage
from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import AgentRuntime, CancellationToken


Expand Down
2 changes: 1 addition & 1 deletion src/agnext/chat/patterns/group_chat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, List, Protocol, Sequence

from ...components.type_routed_agent import TypeRoutedAgent, message_handler
from ...components import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import Reset, RespondNow, TextMessage
Expand Down
2 changes: 1 addition & 1 deletion src/agnext/chat/patterns/orchestrator_chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Any, Sequence, Tuple

from ...components.type_routed_agent import TypeRoutedAgent, message_handler
from ...components import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import Reset, RespondNow, ResponseFormat, TextMessage
Expand Down
3 changes: 1 addition & 2 deletions src/agnext/chat/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from enum import Enum
from typing import List, Union

from agnext.components.image import Image
from agnext.components import FunctionCall, Image
from agnext.components.models import FunctionExecutionResultMessage
from agnext.components.types import FunctionCall


@dataclass(kw_only=True)
Expand Down
6 changes: 6 additions & 0 deletions src/agnext/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""
The :mod:`agnext.components` module provides building blocks for creating single agents
"""

from ._image import Image
from ._type_routed_agent import TypeRoutedAgent, message_handler
from ._types import FunctionCall

__all__ = ["Image", "TypeRoutedAgent", "message_handler", "FunctionCall"]
76 changes: 63 additions & 13 deletions src/agnext/components/_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
from logging import getLogger
from typing import (
Annotated,
Any,
Callable,
Dict,
Expand All @@ -15,10 +16,13 @@
Type,
TypeVar,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel, Field
from typing_extensions import Annotated, Literal
from pydantic import BaseModel, Field, create_model # type: ignore
from pydantic_core import PydanticUndefined
from typing_extensions import Literal

from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema

Expand Down Expand Up @@ -125,6 +129,18 @@ class ToolFunction(BaseModel):
function: Annotated[Function, Field(description="Function under tool")]


def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
# handles Annotated
if hasattr(v, "__metadata__"):
retval = v.__metadata__[0]
if isinstance(retval, str):
return retval
else:
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
else:
return k


def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) -> Dict[str, Any]:
"""Get a JSON schema for a parameter as defined by the OpenAI API
Expand All @@ -137,17 +153,6 @@ def get_parameter_json_schema(k: str, v: Any, default_values: Dict[str, Any]) ->
A Pydanitc model for the parameter
"""

def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
# handles Annotated
if hasattr(v, "__metadata__"):
retval = v.__metadata__[0]
if isinstance(retval, str):
return retval
else:
raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
else:
return k

schema = type2schema(v)
if k in default_values:
dv = default_values[k]
Expand Down Expand Up @@ -297,3 +302,48 @@ def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Paramet
)

return model_dump(function)


def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:
"""Normalize typing.Annotated types to the inner type."""
if get_origin(type_hint) is Annotated:
# Extract the inner type from Annotated
return get_args(type_hint)[0] # type: ignore
return type_hint


def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
fields: List[tuple[str, Any]] = []
for name, param in sig.parameters.items():
# This is handled externally
if name == "cancellation_token":
continue

if param.annotation is inspect.Parameter.empty:
raise ValueError("No annotation")

type = normalize_annotated_type(param.annotation)
description = type2description(name, param.annotation)
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined

fields.append((name, (type, Field(default=default_value, description=description))))

return create_model(name, *fields)


def return_value_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
if issubclass(BaseModel, sig.return_annotation):
return sig.return_annotation # type: ignore

fields: List[tuple[str, Any]] = []
for name, param in sig.return_annotation:
if param.annotation is inspect.Parameter.empty:
raise ValueError("No annotation")

type = normalize_annotated_type(param.annotation)
description = type2description(name, param.annotation)
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined

fields.append((name, (type, Field(default=default_value, description=description))))

return create_model(name, *fields)
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit ab420cc

Please sign in to comment.