diff --git a/examples/patterns.py b/examples/patterns.py index 1cb5707cc49..32f889ffddb 100644 --- a/examples/patterns.py +++ b/examples/patterns.py @@ -3,14 +3,16 @@ import openai from agnext.agent_components.models_clients.openai_client import OpenAI -from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime +from agnext.application_components.single_threaded_agent_runtime import ( + SingleThreadedAgentRuntime, +) from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent from agnext.chat.messages import ChatMessage from agnext.chat.patterns.group_chat import GroupChat from agnext.chat.patterns.orchestrator import Orchestrator -async def group_chat() -> None: +async def group_chat(message: str) -> None: runtime = SingleThreadedAgentRuntime() joe_oai_assistant = openai.beta.assistants.create( @@ -44,14 +46,14 @@ async def group_chat() -> None: ) chat = GroupChat( - "chat_room", + "Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, ) - response = runtime.send_message(ChatMessage(body="Run a show!", sender="external"), chat) + response = runtime.send_message(ChatMessage(body=message, sender="host"), chat) while not response.done(): await runtime.process_next() @@ -59,7 +61,7 @@ async def group_chat() -> None: print((await response).body) # type: ignore -async def orchestrator() -> None: +async def orchestrator(message: str) -> None: runtime = SingleThreadedAgentRuntime() developer_oai_assistant = openai.beta.assistants.create( @@ -93,8 +95,8 @@ async def orchestrator() -> None: ) chat = Orchestrator( - "Team", - "A software development team.", + "Manager", + "A software development team manager.", runtime, [developer, product_manager], model_client=OpenAI(model="gpt-3.5-turbo"), @@ -102,7 +104,7 @@ async def orchestrator() -> None: response = runtime.send_message( ChatMessage( - body="Write a simple FastAPI webapp for showing the current time.", + body=message, sender="customer", ), chat, @@ -122,11 +124,12 @@ async def orchestrator() -> None: choices=chocies, help="The pattern to demo.", ) + parser.add_argument("--message", help="The message to send.") args = parser.parse_args() if args.pattern == "group_chat": - asyncio.run(group_chat()) + asyncio.run(group_chat(args.message)) elif args.pattern == "orchestrator": - asyncio.run(orchestrator()) + asyncio.run(orchestrator(args.message)) else: raise ValueError(f"Invalid pattern: {args.pattern}") diff --git a/src/agnext/chat/patterns/group_chat.py b/src/agnext/chat/patterns/group_chat.py index f837b6f2750..4efd490f210 100644 --- a/src/agnext/chat/patterns/group_chat.py +++ b/src/agnext/chat/patterns/group_chat.py @@ -1,6 +1,8 @@ from typing import List, Sequence +from ...agent_components.type_routed_agent import message_handler from ...core.agent_runtime import AgentRuntime +from ...core.cancellation_token import CancellationToken from ..agents.base import BaseChatAgent from ..messages import ChatMessage @@ -19,7 +21,13 @@ def __init__( self._num_rounds = num_rounds self._history: List[ChatMessage] = [] - async def on_chat_message(self, message: ChatMessage) -> ChatMessage: + @message_handler(ChatMessage) + async def on_chat_message( + self, + message: ChatMessage, + require_response: bool, + cancellation_token: CancellationToken, + ) -> ChatMessage | None: if message.reset: # Reset the history. self._history = [] diff --git a/src/agnext/chat/patterns/orchestrator.py b/src/agnext/chat/patterns/orchestrator.py index 0c5e4915d1e..0798da8260c 100644 --- a/src/agnext/chat/patterns/orchestrator.py +++ b/src/agnext/chat/patterns/orchestrator.py @@ -1,12 +1,11 @@ import json from typing import Any, List, Sequence, Tuple -from agnext.core.agent_runtime import AgentRuntime -from agnext.core.cancellation_token import CancellationToken - from ...agent_components.model_client import ModelClient from ...agent_components.type_routed_agent import message_handler from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage +from ...core.agent_runtime import AgentRuntime +from ...core.cancellation_token import CancellationToken from ..agents.base import BaseChatAgent from ..messages import ChatMessage @@ -33,8 +32,11 @@ def __init__( @message_handler(ChatMessage) async def on_chat_message( - self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken - ) -> ChatMessage: + self, + message: ChatMessage, + require_response: bool, + cancellation_token: CancellationToken, + ) -> ChatMessage | None: # A task is received. task = message.body diff --git a/src/agnext/chat/utils.py b/src/agnext/chat/utils.py index b554e2d2fe9..4f77c77be98 100644 --- a/src/agnext/chat/utils.py +++ b/src/agnext/chat/utils.py @@ -1,8 +1,9 @@ from typing import List, Optional, Union +from typing_extensions import Literal + from agnext.agent_components.types import AssistantMessage, LLMMessage, UserMessage from agnext.chat.types import FunctionCallMessage, Message, MultiModalMessage, TextMessage -from typing_extensions import Literal def convert_content_message_to_assistant_message(