From 5b01f69b5877723cba565d9e44b818855f3f30d3 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 18 Jun 2024 14:53:18 -0400 Subject: [PATCH] Move agent creation into the runtime (#89) * Move agent creation into the runtime * update doc * add test * Remove limitation of subscriptions being same across namespaces * constrain agent types to namespaces --- docs/src/guides/type-routed-agent.md | 5 +- examples/assistant.py | 55 ++-- examples/chat_room.py | 68 +++-- examples/chess_game.py | 78 ++--- examples/coder_reviewer.py | 89 +++--- examples/illustrator_critics.py | 112 ++++---- examples/{inner_outter.py => inner_outer.py} | 16 +- examples/orchestrator.py | 92 +++--- examples/software_consultancy.py | 267 +++++++++--------- examples/utils.py | 8 +- .../_single_threaded_agent_runtime.py | 255 ++++++++++++----- .../chat/agents/chat_completion_agent.py | 10 +- .../chat/agents/image_generation_agent.py | 6 +- src/agnext/chat/agents/oai_assistant.py | 6 +- src/agnext/chat/agents/user_proxy.py | 6 +- .../chat/patterns/group_chat_manager.py | 3 +- src/agnext/chat/patterns/orchestrator_chat.py | 65 ++--- src/agnext/components/_type_routed_agent.py | 6 +- src/agnext/core/__init__.py | 3 +- src/agnext/core/_agent_metadata.py | 1 + src/agnext/core/_agent_props.py | 6 +- src/agnext/core/_agent_proxy.py | 15 +- src/agnext/core/_agent_runtime.py | 119 ++++++-- src/agnext/core/_base_agent.py | 55 +++- src/agnext/core/intervention.py | 16 +- tests/test_cancellation.py | 73 ++--- tests/test_intervention.py | 70 ++--- tests/test_runtime.py | 72 ++++- tests/test_state.py | 16 +- tests/test_utils/__init__.py | 20 ++ 30 files changed, 970 insertions(+), 643 deletions(-) rename examples/{inner_outter.py => inner_outer.py} (79%) create mode 100644 tests/test_utils/__init__.py diff --git a/docs/src/guides/type-routed-agent.md b/docs/src/guides/type-routed-agent.md index 8c9ba64bbe9..ed6dc1648b7 100644 --- a/docs/src/guides/type-routed-agent.md +++ b/docs/src/guides/type-routed-agent.md @@ -29,14 +29,15 @@ from agnext.core import AgentRuntime, CancellationToken class MyAgent(TypeRoutedAgent): - def __init__(self, name: str, runtime: AgentRuntime): - super().__init__(name, "I am a demo agent", runtime) + def __init__(self): + super().__init__(description="I am a demo agent") self._received_count = 0 @message_handler() async def on_text_message( self, message: TextMessage | MultiModalMessage, cancellation_token: CancellationToken ) -> None: + self._received_count += 1 await self._publish_message( TextMessage( content=f"I received a message from {message.source}. Message received #{self._received_count}", diff --git a/examples/assistant.py b/examples/assistant.py index a3b1afebbed..989a57a9284 100644 --- a/examples/assistant.py +++ b/examples/assistant.py @@ -16,7 +16,7 @@ from agnext.chat.patterns.group_chat_manager import GroupChatManager from agnext.chat.types import PublishNow, TextMessage from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import AgentRuntime, CancellationToken +from agnext.core import AgentId, AgentRuntime, CancellationToken from openai import AsyncAssistantEventHandler from openai.types.beta.thread import ToolResources from openai.types.beta.threads import Message, Text, TextDelta @@ -29,17 +29,13 @@ class UserProxyAgent(TypeRoutedAgent): # type: ignore def __init__( # type: ignore self, - name: str, - runtime: AgentRuntime, # type: ignore client: openai.AsyncClient, # type: ignore assistant_id: str, thread_id: str, vector_store_id: str, ) -> None: # type: ignore super().__init__( - name=name, description="A human user", - runtime=runtime, ) # type: ignore self._client = client self._assistant_id = assistant_id @@ -166,7 +162,7 @@ async def on_message_done(self, message: Message) -> None: print("\n".join(citations)) -def assistant_chat(runtime: AgentRuntime) -> UserProxyAgent: # type: ignore +def assistant_chat(runtime: AgentRuntime) -> AgentId: oai_assistant = openai.beta.assistants.create( model="gpt-4-turbo", description="An AI assistant that helps with everyday tasks.", @@ -177,30 +173,35 @@ def assistant_chat(runtime: AgentRuntime) -> UserProxyAgent: # type: ignore thread = openai.beta.threads.create( tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, ) - assistant = OpenAIAssistantAgent( - name="Assistant", - description="An AI assistant that helps with everyday tasks.", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=oai_assistant.id, - thread_id=thread.id, - assistant_event_handler_factory=lambda: EventHandler(), + assistant = runtime.register_and_get( + "Assistant", + lambda: OpenAIAssistantAgent( + description="An AI assistant that helps with everyday tasks.", + client=openai.AsyncClient(), + assistant_id=oai_assistant.id, + thread_id=thread.id, + assistant_event_handler_factory=lambda: EventHandler(), + ), ) - user = UserProxyAgent( - name="User", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=oai_assistant.id, - thread_id=thread.id, - vector_store_id=vector_store.id, + + user = runtime.register_and_get( + "User", + lambda: UserProxyAgent( + client=openai.AsyncClient(), + assistant_id=oai_assistant.id, + thread_id=thread.id, + vector_store_id=vector_store.id, + ), ) # Create a group chat manager to facilitate a turn-based conversation. - _ = GroupChatManager( - name="GroupChatManager", - description="A group chat manager.", - runtime=runtime, - memory=BufferedChatMemory(buffer_size=10), - participants=[assistant.id, user.id], + runtime.register( + "GroupChatManager", + lambda: GroupChatManager( + description="A group chat manager.", + runtime=runtime, + memory=BufferedChatMemory(buffer_size=10), + participants=[assistant, user], + ), ) return user diff --git a/examples/chat_room.py b/examples/chat_room.py index 06eea51e393..e3cd584a274 100644 --- a/examples/chat_room.py +++ b/examples/chat_room.py @@ -23,12 +23,11 @@ def __init__( # type: ignore self, name: str, description: str, - runtime: AgentRuntime, # type: ignore background_story: str, memory: ChatMemory, # type: ignore model_client: ChatCompletionClient, # type: ignore ) -> None: # type: ignore - super().__init__(name, description, runtime) + super().__init__(description) system_prompt = f"""Your name is {name}. Your background story is: {background_story} @@ -86,40 +85,47 @@ async def on_chat_room_message(self, message: TextMessage, cancellation_token: C # Define a chat room with participants -- the runtime is the chat room. def chat_room(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore - _ = ChatRoomUserAgent( - name="User", - description="The user in the chat room.", - runtime=runtime, - app=app, + runtime.register( + "User", + lambda: ChatRoomUserAgent( + description="The user in the chat room.", + app=app, + ), ) - alice = ChatRoomAgent( - name="Alice", - description="Alice in the chat room.", - runtime=runtime, - background_story="Alice is a software engineer who loves to code.", - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), # type: ignore + alice = runtime.register_and_get_proxy( + "Alice", + lambda rt, id: ChatRoomAgent( + name=id.name, + description="Alice in the chat room.", + background_story="Alice is a software engineer who loves to code.", + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), # type: ignore + ), ) - bob = ChatRoomAgent( - name="Bob", - description="Bob in the chat room.", - runtime=runtime, - background_story="Bob is a data scientist who loves to analyze data.", - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), # type: ignore + bob = runtime.register_and_get_proxy( + "Bob", + lambda rt, id: ChatRoomAgent( + name=id.name, + description="Bob in the chat room.", + background_story="Bob is a data scientist who loves to analyze data.", + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), # type: ignore + ), ) - charlie = ChatRoomAgent( - name="Charlie", - description="Charlie in the chat room.", - runtime=runtime, - background_story="Charlie is a designer who loves to create art.", - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), # type: ignore + charlie = runtime.register_and_get_proxy( + "Charlie", + lambda rt, id: ChatRoomAgent( + name=id.name, + description="Charlie in the chat room.", + background_story="Charlie is a designer who loves to create art.", + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), # type: ignore + ), ) app.welcoming_notice = f"""Welcome to the chat room demo with the following participants: -1. 👧 {alice.metadata['name']}: {alice.metadata['description']} -2. 👱🏼‍♂️ {bob.metadata['name']}: {bob.metadata['description']} -3. 👨🏾‍🦳 {charlie.metadata['name']}: {charlie.metadata['description']} +1. 👧 {alice.id.name}: {alice.metadata['description']} +2. 👱🏼‍♂️ {bob.id.name}: {bob.metadata['description']} +3. 👨🏾‍🦳 {charlie.id.name}: {charlie.metadata['description']} Each participant decides on its own whether to respond to the latest message. diff --git a/examples/chess_game.py b/examples/chess_game.py index 230d46d0ef7..0e6b52dda79 100644 --- a/examples/chess_game.py +++ b/examples/chess_game.py @@ -150,46 +150,50 @@ def get_board_text() -> Annotated[str, "The current board state"]: ), ] - black = ChatCompletionAgent( - name="PlayerBlack", - description="Player playing black.", - runtime=runtime, - system_messages=[ - SystemMessage( - content="You are a chess player and you play as black. " - "Use get_legal_moves() to get list of legal moves. " - "Use get_board() to get the current board state. " - "Think about your strategy and call make_move(thinking, move) to make a move." - ), - ], - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), - tools=black_tools, + black = runtime.register_and_get( + "PlayerBlack", + lambda: ChatCompletionAgent( + description="Player playing black.", + system_messages=[ + SystemMessage( + content="You are a chess player and you play as black. " + "Use get_legal_moves() to get list of legal moves. " + "Use get_board() to get the current board state. " + "Think about your strategy and call make_move(thinking, move) to make a move." + ), + ], + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), + tools=black_tools, + ), ) - white = ChatCompletionAgent( - name="PlayerWhite", - description="Player playing white.", - runtime=runtime, - system_messages=[ - SystemMessage( - content="You are a chess player and you play as white. " - "Use get_legal_moves() to get list of legal moves. " - "Use get_board() to get the current board state. " - "Think about your strategy and call make_move(thinking, move) to make a move." - ), - ], - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), - tools=white_tools, + white = runtime.register_and_get( + "PlayerWhite", + lambda: ChatCompletionAgent( + description="Player playing white.", + system_messages=[ + SystemMessage( + content="You are a chess player and you play as white. " + "Use get_legal_moves() to get list of legal moves. " + "Use get_board() to get the current board state. " + "Think about your strategy and call make_move(thinking, move) to make a move." + ), + ], + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), + tools=white_tools, + ), ) # Create a group chat manager for the chess game to orchestrate a turn-based # conversation between the two agents. - _ = GroupChatManager( - name="ChessGame", - description="A chess game between two agents.", - runtime=runtime, - memory=BufferedChatMemory(buffer_size=10), - participants=[white.id, black.id], # white goes first + runtime.register( + "ChessGame", + lambda: GroupChatManager( + description="A chess game between two agents.", + runtime=runtime, + memory=BufferedChatMemory(buffer_size=10), + participants=[white, black], # white goes first + ), ) @@ -197,7 +201,7 @@ async def main() -> None: runtime = SingleThreadedAgentRuntime() chess_game(runtime) # Publish an initial message to trigger the group chat manager to start orchestration. - runtime.publish_message(TextMessage(content="Game started.", source="System")) + runtime.publish_message(TextMessage(content="Game started.", source="System"), namespace="default") while True: await runtime.process_next() await asyncio.sleep(1) diff --git a/examples/coder_reviewer.py b/examples/coder_reviewer.py index 89e40711115..649792a4f0f 100644 --- a/examples/coder_reviewer.py +++ b/examples/coder_reviewer.py @@ -17,51 +17,56 @@ def coder_reviewer(runtime: AgentRuntime, app: TextualChatApp) -> None: - _ = TextualUserAgent( - name="Human", - description="A human user that provides a problem statement.", - runtime=runtime, - app=app, + runtime.register( + "Human", + lambda: TextualUserAgent( + description="A human user that provides a problem statement.", + app=app, + ), ) - coder = ChatCompletionAgent( - name="Coder", - description="An agent that writes code", - runtime=runtime, - system_messages=[ - SystemMessage( - "You are a coder. You can write code to solve problems.\n" - "Work with the reviewer to improve your code." - ) - ], - model_client=OpenAI(model="gpt-4-turbo"), - memory=BufferedChatMemory(buffer_size=10), + coder = runtime.register_and_get_proxy( + "Coder", + lambda: ChatCompletionAgent( + description="An agent that writes code", + system_messages=[ + SystemMessage( + "You are a coder. You can write code to solve problems.\n" + "Work with the reviewer to improve your code." + ) + ], + model_client=OpenAI(model="gpt-4-turbo"), + memory=BufferedChatMemory(buffer_size=10), + ), ) - reviewer = ChatCompletionAgent( - name="Reviewer", - description="An agent that reviews code", - runtime=runtime, - system_messages=[ - SystemMessage( - "You are a code reviewer. You focus on correctness, efficiency and safety of the code.\n" - "Respond using the following format:\n" - "Code Review:\n" - "Correctness: \n" - "Efficiency: \n" - "Safety: \n" - "Approval: \n" - "Suggested Changes: " - ) - ], - model_client=OpenAI(model="gpt-4-turbo"), - memory=BufferedChatMemory(buffer_size=10), + reviewer = runtime.register_and_get_proxy( + "Reviewer", + lambda: ChatCompletionAgent( + description="An agent that reviews code", + system_messages=[ + SystemMessage( + "You are a code reviewer. You focus on correctness, efficiency and safety of the code.\n" + "Respond using the following format:\n" + "Code Review:\n" + "Correctness: \n" + "Efficiency: \n" + "Safety: \n" + "Approval: \n" + "Suggested Changes: " + ) + ], + model_client=OpenAI(model="gpt-4-turbo"), + memory=BufferedChatMemory(buffer_size=10), + ), ) - _ = GroupChatManager( - name="Manager", - description="A manager that orchestrates a back-and-forth converation between a coder and a reviewer.", - runtime=runtime, - participants=[coder.id, reviewer.id], # The order of the participants indicates the order of speaking. - memory=BufferedChatMemory(buffer_size=10), - termination_word="APPROVE", + runtime.register( + "Manager", + lambda: GroupChatManager( + description="A manager that orchestrates a back-and-forth converation between a coder and a reviewer.", + runtime=runtime, + participants=[coder.id, reviewer.id], # The order of the participants indicates the order of speaking. + memory=BufferedChatMemory(buffer_size=10), + termination_word="APPROVE", + ), ) app.welcoming_notice = f"""Welcome to the coder-reviewer demo with the following roles: 1. 🤖 {coder.metadata['name']}: {coder.metadata['description']} diff --git a/examples/illustrator_critics.py b/examples/illustrator_critics.py index 0de48efe48b..4cf0976f149 100644 --- a/examples/illustrator_critics.py +++ b/examples/illustrator_critics.py @@ -17,63 +17,69 @@ def illustrator_critics(runtime: AgentRuntime, app: TextualChatApp) -> str: # type: ignore - _ = TextualUserAgent( - name="User", - description="A user looking for illustration.", - runtime=runtime, - app=app, + runtime.register( + "User", + lambda: TextualUserAgent( + description="A user looking for illustration.", + app=app, + ), ) - descriptor = ChatCompletionAgent( - name="Descriptor", - description="An AI agent that provides a description of the image.", - runtime=runtime, - system_messages=[ - SystemMessage( - "You create short description for image. \n" - "In this conversation, you will be given either: \n" - "1. Request for new image. \n" - "2. Feedback on some image created. \n" - "In both cases, you will provide a description of a new image to be created. \n" - "Only provide the description of the new image and nothing else. \n" - "Be succinct and precise." - ), - ], - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo", max_tokens=500), + descriptor = runtime.register_and_get_proxy( + "Descriptor", + lambda: ChatCompletionAgent( + description="An AI agent that provides a description of the image.", + system_messages=[ + SystemMessage( + "You create short description for image. \n" + "In this conversation, you will be given either: \n" + "1. Request for new image. \n" + "2. Feedback on some image created. \n" + "In both cases, you will provide a description of a new image to be created. \n" + "Only provide the description of the new image and nothing else. \n" + "Be succinct and precise." + ), + ], + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo", max_tokens=500), + ), ) - illustrator = ImageGenerationAgent( - name="Illustrator", - description="An AI agent that generates images.", - runtime=runtime, - client=openai.AsyncOpenAI(), - model="dall-e-3", - memory=BufferedChatMemory(buffer_size=1), + illustrator = runtime.register_and_get_proxy( + "Illustrator", + lambda: ImageGenerationAgent( + description="An AI agent that generates images.", + client=openai.AsyncOpenAI(), + model="dall-e-3", + memory=BufferedChatMemory(buffer_size=1), + ), ) - critic = ChatCompletionAgent( - name="Critic", - description="An AI agent that provides feedback on images given user's requirements.", - runtime=runtime, - system_messages=[ - SystemMessage( - "You are an expert in image understanding. \n" - "In this conversation, you will judge an image given the description and provide feedback. \n" - "Pay attention to the details like the spelling of words and number of objects. \n" - "Use the following format in your response: \n" - "Number of each object type in the image: : 1, : 2, ...\n" - "Feedback: \n" - "Approval: \n" - ), - ], - memory=BufferedChatMemory(buffer_size=2), - model_client=OpenAI(model="gpt-4-turbo"), + critic = runtime.register_and_get_proxy( + "Critic", + lambda: ChatCompletionAgent( + description="An AI agent that provides feedback on images given user's requirements.", + system_messages=[ + SystemMessage( + "You are an expert in image understanding. \n" + "In this conversation, you will judge an image given the description and provide feedback. \n" + "Pay attention to the details like the spelling of words and number of objects. \n" + "Use the following format in your response: \n" + "Number of each object type in the image: : 1, : 2, ...\n" + "Feedback: \n" + "Approval: \n" + ), + ], + memory=BufferedChatMemory(buffer_size=2), + model_client=OpenAI(model="gpt-4-turbo"), + ), ) - _ = GroupChatManager( - name="GroupChatManager", - description="A chat manager that handles group chat.", - runtime=runtime, - memory=BufferedChatMemory(buffer_size=5), - participants=[illustrator.id, critic.id, descriptor.id], - termination_word="APPROVE", + runtime.register( + "GroupChatManager", + lambda: GroupChatManager( + description="A chat manager that handles group chat.", + runtime=runtime, + memory=BufferedChatMemory(buffer_size=5), + participants=[illustrator.id, critic.id, descriptor.id], + termination_word="APPROVE", + ), ) app.welcoming_notice = f"""You are now in a group chat with the following agents: diff --git a/examples/inner_outter.py b/examples/inner_outer.py similarity index 79% rename from examples/inner_outter.py rename to examples/inner_outer.py index 5de492dee6e..4bc13323735 100644 --- a/examples/inner_outter.py +++ b/examples/inner_outer.py @@ -5,7 +5,7 @@ from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import AgentProxy, AgentRuntime, CancellationToken +from agnext.core import AgentId, CancellationToken @dataclass @@ -15,8 +15,8 @@ class MessageType: class Inner(TypeRoutedAgent): # type: ignore - def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore - super().__init__(name, "The inner agent", runtime) + def __init__(self) -> None: # type: ignore + super().__init__("The inner agent") @message_handler() # type: ignore async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore @@ -24,13 +24,13 @@ async def on_new_message(self, message: MessageType, cancellation_token: Cancell class Outer(TypeRoutedAgent): # type: ignore - def __init__(self, name: str, runtime: AgentRuntime, inner: AgentProxy) -> None: # type: ignore - super().__init__(name, "The outter agent", runtime) + def __init__(self, inner: AgentId) -> None: # type: ignore + super().__init__("The outer agent") self._inner = inner @message_handler() # type: ignore async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore - inner_response = self._send_message(message, self._inner.id) + inner_response = self._send_message(message, self._inner) inner_message = await inner_response assert isinstance(inner_message, MessageType) return MessageType(body=f"Outer: {inner_message.body}", sender=self.metadata["name"]) @@ -38,8 +38,8 @@ async def on_new_message(self, message: MessageType, cancellation_token: Cancell async def main() -> None: runtime = SingleThreadedAgentRuntime() - inner = Inner("inner", runtime) - outer = Outer("outer", runtime, AgentProxy(inner, runtime)) + inner = runtime.register_and_get("inner", Inner) + outer = runtime.register_and_get("outer", lambda: Outer(inner)) response = runtime.send_message(MessageType(body="Hello", sender="external"), outer) while not response.done(): diff --git a/examples/orchestrator.py b/examples/orchestrator.py index 43111d0b12f..77ffa3537e3 100644 --- a/examples/orchestrator.py +++ b/examples/orchestrator.py @@ -79,13 +79,14 @@ async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | N def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ignore - developer = ChatCompletionAgent( - name="Developer", - description="A developer that writes code.", - runtime=runtime, - system_messages=[SystemMessage("You are a Python developer.")], - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), + developer = runtime.register_and_get_proxy( + "Developer", + lambda: ChatCompletionAgent( + description="A developer that writes code.", + system_messages=[SystemMessage("You are a Python developer.")], + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), + ), ) tester_oai_assistant = openai.beta.assistants.create( @@ -94,50 +95,55 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig instructions="You are a software tester that runs test cases and reports results.", ) tester_oai_thread = openai.beta.threads.create() - tester = OpenAIAssistantAgent( - name="Tester", - description="A software tester that runs test cases and reports results.", - runtime=runtime, - client=openai.AsyncClient(), - assistant_id=tester_oai_assistant.id, - thread_id=tester_oai_thread.id, + tester = runtime.register_and_get_proxy( + "Tester", + lambda: OpenAIAssistantAgent( + description="A software tester that runs test cases and reports results.", + client=openai.AsyncClient(), + assistant_id=tester_oai_assistant.id, + thread_id=tester_oai_thread.id, + ), ) - product_manager = ChatCompletionAgent( - name="ProductManager", - description="A product manager that performs research and comes up with specs.", - runtime=runtime, - system_messages=[ - SystemMessage("You are a product manager good at translating customer needs into software specifications."), - SystemMessage("You can use the search tool to find information on the web."), - ], - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), - tools=[SearchTool()], + product_manager = runtime.register_and_get_proxy( + "ProductManager", + lambda: ChatCompletionAgent( + description="A product manager that performs research and comes up with specs.", + system_messages=[ + SystemMessage( + "You are a product manager good at translating customer needs into software specifications." + ), + SystemMessage("You can use the search tool to find information on the web."), + ], + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), + tools=[SearchTool()], + ), ) - planner = ChatCompletionAgent( - name="Planner", - description="A planner that organizes and schedules tasks.", - runtime=runtime, - system_messages=[SystemMessage("You are a planner of complex tasks.")], - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), + planner = runtime.register_and_get_proxy( + "Planner", + lambda: ChatCompletionAgent( + description="A planner that organizes and schedules tasks.", + system_messages=[SystemMessage("You are a planner of complex tasks.")], + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), + ), ) - orchestrator = ChatCompletionAgent( - name="Orchestrator", - description="An orchestrator that coordinates the team.", - runtime=runtime, - system_messages=[ - SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.") - ], - memory=BufferedChatMemory(buffer_size=10), - model_client=OpenAI(model="gpt-4-turbo"), + orchestrator = runtime.register_and_get_proxy( + "Orchestrator", + lambda: ChatCompletionAgent( + description="An orchestrator that coordinates the team.", + system_messages=[ + SystemMessage("You are an orchestrator that coordinates the team to complete a complex task.") + ], + memory=BufferedChatMemory(buffer_size=10), + model_client=OpenAI(model="gpt-4-turbo"), + ), ) return OrchestratorChat( - "OrchestratorChat", "A software development team.", runtime, orchestrator=orchestrator.id, @@ -149,7 +155,7 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig async def run(message: str, user: str, scenario: Callable[[AgentRuntime], OrchestratorChat]) -> None: # type: ignore runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler()) chat = scenario(runtime) - response = runtime.send_message(TextMessage(content=message, source=user), chat) + response = runtime.send_message(TextMessage(content=message, source=user), chat.id) while not response.done(): await runtime.process_next() print((await response).content) # type: ignore diff --git a/examples/software_consultancy.py b/examples/software_consultancy.py index 396d33a8e43..c247d6a48ed 100644 --- a/examples/software_consultancy.py +++ b/examples/software_consultancy.py @@ -103,140 +103,149 @@ async def create_image( def software_consultancy(runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore - user_agent = TextualUserAgent( - name="Customer", - description="A customer looking for help.", - runtime=runtime, - app=app, + user_agent = runtime.register_and_get( + "Customer", + lambda: TextualUserAgent( + description="A customer looking for help.", + app=app, + ), ) - developer = ChatCompletionAgent( - name="Developer", - description="A Python software developer.", - runtime=runtime, - system_messages=[ - SystemMessage( - "Your are a Python developer. \n" - "You can read, write, and execute code. \n" - "You can browse files and directories. \n" - "You can also browse the web for documentation. \n" - "You are entering a work session with the customer, product manager, UX designer, and illustrator. \n" - "When you are given a task, you should immediately start working on it. \n" - "Be concise and deliver now." - ) - ], - model_client=OpenAI(model="gpt-4-turbo"), - memory=HeadAndTailChatMemory(head_size=1, tail_size=10), - tools=[ - FunctionTool( - write_file, - name="write_file", - description="Write code to a file.", - ), - FunctionTool( - read_file, - name="read_file", - description="Read code from a file.", - ), - FunctionTool( - execute_command, - name="execute_command", - description="Execute a unix shell command.", - ), - FunctionTool(list_files, name="list_files", description="List files in a directory."), - FunctionTool(browse_web, name="browse_web", description="Browse a web page."), - ], - tool_approver=user_agent, + developer = runtime.register_and_get( + "Developer", + lambda: ChatCompletionAgent( + description="A Python software developer.", + system_messages=[ + SystemMessage( + "Your are a Python developer. \n" + "You can read, write, and execute code. \n" + "You can browse files and directories. \n" + "You can also browse the web for documentation. \n" + "You are entering a work session with the customer, product manager, UX designer, and illustrator. \n" + "When you are given a task, you should immediately start working on it. \n" + "Be concise and deliver now." + ) + ], + model_client=OpenAI(model="gpt-4-turbo"), + memory=HeadAndTailChatMemory(head_size=1, tail_size=10), + tools=[ + FunctionTool( + write_file, + name="write_file", + description="Write code to a file.", + ), + FunctionTool( + read_file, + name="read_file", + description="Read code from a file.", + ), + FunctionTool( + execute_command, + name="execute_command", + description="Execute a unix shell command.", + ), + FunctionTool(list_files, name="list_files", description="List files in a directory."), + FunctionTool(browse_web, name="browse_web", description="Browse a web page."), + ], + tool_approver=user_agent, + ), ) - product_manager = ChatCompletionAgent( - name="ProductManager", - description="A product manager. " - "Responsible for interfacing with the customer, planning and managing the project. ", - runtime=runtime, - system_messages=[ - SystemMessage( - "You are a product manager. \n" - "You can browse files and directories. \n" - "You are entering a work session with the customer, developer, UX designer, and illustrator. \n" - "Keep the project on track. Don't hire any more people. \n" - "When a milestone is reached, stop and ask for customer feedback. Make sure the customer is satisfied. \n" - "Be VERY concise." - ) - ], - model_client=OpenAI(model="gpt-4-turbo"), - memory=HeadAndTailChatMemory(head_size=1, tail_size=10), - tools=[ - FunctionTool( - read_file, - name="read_file", - description="Read from a file.", - ), - FunctionTool(list_files, name="list_files", description="List files in a directory."), - FunctionTool(browse_web, name="browse_web", description="Browse a web page."), - ], - tool_approver=user_agent, + + product_manager = runtime.register_and_get( + "ProductManager", + lambda: ChatCompletionAgent( + description="A product manager. " + "Responsible for interfacing with the customer, planning and managing the project. ", + system_messages=[ + SystemMessage( + "You are a product manager. \n" + "You can browse files and directories. \n" + "You are entering a work session with the customer, developer, UX designer, and illustrator. \n" + "Keep the project on track. Don't hire any more people. \n" + "When a milestone is reached, stop and ask for customer feedback. Make sure the customer is satisfied. \n" + "Be VERY concise." + ) + ], + model_client=OpenAI(model="gpt-4-turbo"), + memory=HeadAndTailChatMemory(head_size=1, tail_size=10), + tools=[ + FunctionTool( + read_file, + name="read_file", + description="Read from a file.", + ), + FunctionTool(list_files, name="list_files", description="List files in a directory."), + FunctionTool(browse_web, name="browse_web", description="Browse a web page."), + ], + tool_approver=user_agent, + ), ) - ux_designer = ChatCompletionAgent( - name="UserExperienceDesigner", - description="A user experience designer for creating user interfaces.", - runtime=runtime, - system_messages=[ - SystemMessage( - "You are a user experience designer. \n" - "You can create user interfaces from descriptions. \n" - "You can browse files and directories. \n" - "You are entering a work session with the customer, developer, product manager, and illustrator. \n" - "When you are given a task, you should immediately start working on it. \n" - "Be concise and deliver now." - ) - ], - model_client=OpenAI(model="gpt-4-turbo"), - memory=HeadAndTailChatMemory(head_size=1, tail_size=10), - tools=[ - FunctionTool( - write_file, - name="write_file", - description="Write code to a file.", - ), - FunctionTool( - read_file, - name="read_file", - description="Read code from a file.", - ), - FunctionTool(list_files, name="list_files", description="List files in a directory."), - ], - tool_approver=user_agent, + ux_designer = runtime.register_and_get( + "UserExperienceDesigner", + lambda: ChatCompletionAgent( + description="A user experience designer for creating user interfaces.", + system_messages=[ + SystemMessage( + "You are a user experience designer. \n" + "You can create user interfaces from descriptions. \n" + "You can browse files and directories. \n" + "You are entering a work session with the customer, developer, product manager, and illustrator. \n" + "When you are given a task, you should immediately start working on it. \n" + "Be concise and deliver now." + ) + ], + model_client=OpenAI(model="gpt-4-turbo"), + memory=HeadAndTailChatMemory(head_size=1, tail_size=10), + tools=[ + FunctionTool( + write_file, + name="write_file", + description="Write code to a file.", + ), + FunctionTool( + read_file, + name="read_file", + description="Read code from a file.", + ), + FunctionTool(list_files, name="list_files", description="List files in a directory."), + ], + tool_approver=user_agent, + ), ) - illustrator = ChatCompletionAgent( - name="Illustrator", - description="An illustrator for creating images.", - runtime=runtime, - system_messages=[ - SystemMessage( - "You are an illustrator. " - "You can create images from descriptions. " - "You are entering a work session with the customer, developer, product manager, and UX designer. \n" - "When you are given a task, you should immediately start working on it. \n" - "Be concise and deliver now." - ) - ], - model_client=OpenAI(model="gpt-4-turbo"), - memory=HeadAndTailChatMemory(head_size=1, tail_size=10), - tools=[ - FunctionTool( - create_image, - name="create_image", - description="Create an image from a description.", - ), - ], - tool_approver=user_agent, + + illustrator = runtime.register_and_get( + "Illustrator", + lambda: ChatCompletionAgent( + description="An illustrator for creating images.", + system_messages=[ + SystemMessage( + "You are an illustrator. " + "You can create images from descriptions. " + "You are entering a work session with the customer, developer, product manager, and UX designer. \n" + "When you are given a task, you should immediately start working on it. \n" + "Be concise and deliver now." + ) + ], + model_client=OpenAI(model="gpt-4-turbo"), + memory=HeadAndTailChatMemory(head_size=1, tail_size=10), + tools=[ + FunctionTool( + create_image, + name="create_image", + description="Create an image from a description.", + ), + ], + tool_approver=user_agent, + ), ) - _ = GroupChatManager( - name="GroupChatManager", - description="A group chat manager.", - runtime=runtime, - memory=HeadAndTailChatMemory(head_size=1, tail_size=10), - model_client=OpenAI(model="gpt-4-turbo"), - participants=[developer.id, product_manager.id, ux_designer.id, illustrator.id, user_agent.id], + runtime.register( + "GroupChatManager", + lambda: GroupChatManager( + description="A group chat manager.", + runtime=runtime, + memory=HeadAndTailChatMemory(head_size=1, tail_size=10), + model_client=OpenAI(model="gpt-4-turbo"), + participants=[developer, product_manager, ux_designer, illustrator, user_agent], + ), ) art = r""" +----------------------------------------------------------+ diff --git a/examples/utils.py b/examples/utils.py index ba463a61eb7..bed060ffd78 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -130,7 +130,9 @@ async def publish_user_message(self, user_input: str) -> None: # Remove all typing messages. chat_messages.query("#typing").remove() # Publish the user message to the runtime. - await self._runtime.publish_message(TextMessage(source=self._user_name, content=user_input)) + await self._runtime.publish_message( + TextMessage(source=self._user_name, content=user_input), namespace="default" + ) async def post_runtime_message(self, message: TextMessage | MultiModalMessage) -> None: # type: ignore """Post a message from the agent runtime to the message list.""" @@ -151,8 +153,8 @@ async def handle_tool_approval_request(self, message: ToolApprovalRequest) -> To class TextualUserAgent(TypeRoutedAgent): # type: ignore """An agent that is used to receive messages from the runtime.""" - def __init__(self, name: str, description: str, runtime: AgentRuntime, app: TextualChatApp) -> None: # type: ignore - super().__init__(name, description, runtime) + def __init__(self, description: str, app: TextualChatApp) -> None: # type: ignore + super().__init__(description) self._app = app @message_handler # type: ignore diff --git a/src/agnext/application/_single_threaded_agent_runtime.py b/src/agnext/application/_single_threaded_agent_runtime.py index 915c064ceb6..aa4d3456e33 100644 --- a/src/agnext/application/_single_threaded_agent_runtime.py +++ b/src/agnext/application/_single_threaded_agent_runtime.py @@ -1,11 +1,14 @@ import asyncio +import inspect import logging +import threading from asyncio import Future +from collections import defaultdict from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Awaitable, Dict, List, Mapping, Set +from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast -from ..core import Agent, AgentId, AgentMetadata, AgentRuntime, CancellationToken +from ..core import Agent, AgentId, AgentMetadata, AgentProxy, AgentRuntime, AllNamespaces, BaseAgent, CancellationToken from ..core.exceptions import MessageDroppedException from ..core.intervention import DropMessage, InterventionHandler @@ -20,7 +23,8 @@ class PublishMessageEnvelope: message: Any cancellation_token: CancellationToken - sender: Agent | None + sender: AgentId | None + namespace: str @dataclass(kw_only=True) @@ -29,8 +33,8 @@ class SendMessageEnvelope: the message of the type T.""" message: Any - sender: Agent | None - recipient: Agent + sender: AgentId | None + recipient: AgentId future: Future[Any] cancellation_token: CancellationToken @@ -41,31 +45,45 @@ class ResponseMessageEnvelope: message: Any future: Future[Any] - sender: Agent - recipient: Agent | None + sender: AgentId + recipient: AgentId | None + + +P = ParamSpec("P") +T = TypeVar("T", bound=Agent) + + +class Counter: + def __init__(self) -> None: + self._count: int = 0 + self.threadLock = threading.Lock() + + def increment(self) -> None: + self.threadLock.acquire() + self._count += 1 + self.threadLock.release() + + def get(self) -> int: + return self._count + + def decrement(self) -> None: + self.threadLock.acquire() + self._count -= 1 + self.threadLock.release() class SingleThreadedAgentRuntime(AgentRuntime): def __init__(self, *, before_send: InterventionHandler | None = None) -> None: self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] - self._per_type_subscribers: Dict[type, List[Agent]] = {} - self._agents: Set[Agent] = set() + # (namespace, type) -> List[AgentId] + self._per_type_subscribers: DefaultDict[tuple[str, type], Set[AgentId]] = defaultdict(set) + self._agent_factories: Dict[str, Callable[[], Agent] | Callable[[AgentRuntime, AgentId], Agent]] = {} + # If empty, then all namespaces are valid for that agent type + self._valid_namespaces: Dict[str, Sequence[str]] = {} + self._instantiated_agents: Dict[AgentId, Agent] = {} self._before_send = before_send - - def add_agent(self, agent: Agent) -> None: - agent_names = {agent.metadata["name"] for agent in self._agents} - if agent.metadata["name"] in agent_names: - raise ValueError(f"Agent with name {agent.metadata['name']} already exists. Agent names must be unique.") - - for message_type in agent.metadata["subscriptions"]: - if message_type not in self._per_type_subscribers: - self._per_type_subscribers[message_type] = [] - self._per_type_subscribers[message_type].append(agent) - self._agents.add(agent) - - @property - def agents(self) -> Sequence[Agent]: - return list(self._agents) + self._known_namespaces: set[str] = set() + self._outstanding_tasks = Counter() @property def unprocessed_messages( @@ -73,13 +91,21 @@ def unprocessed_messages( ) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]: return self._message_queue + @property + def outstanding_tasks(self) -> int: + return self._outstanding_tasks.get() + + @property + def _known_agent_names(self) -> Set[str]: + return set(self._agent_factories.keys()) + # Returns the response of the message def send_message( self, message: Any, - recipient: Agent | AgentId, + recipient: AgentId, *, - sender: Agent | AgentId | None = None, + sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> Future[Any | None]: if cancellation_token is None: @@ -95,18 +121,18 @@ def send_message( # ) # ) - recipient = self._get_agent(recipient) - if sender is not None: - sender = self._get_agent(sender) - - logger.info( - f"Sending message of type {type(message).__name__} to {recipient.metadata['name']}: {message.__dict__}" - ) + if recipient.namespace not in self._known_namespaces: + self._prepare_namespace(recipient.namespace) future = asyncio.get_event_loop().create_future() - if recipient not in self._agents: + if recipient.name not in self._known_agent_names: future.set_exception(Exception("Recipient not found")) + if sender is not None and sender.namespace != recipient.namespace: + raise ValueError("Sender and recipient must be in the same namespace to communicate.") + + logger.info(f"Sending message of type {type(message).__name__} to {recipient.name}: {message.__dict__}") + self._message_queue.append( SendMessageEnvelope( message=message, @@ -123,15 +149,13 @@ def publish_message( self, message: Any, *, - sender: Agent | AgentId | None = None, + namespace: str | None = None, + sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> Future[None]: if cancellation_token is None: cancellation_token = CancellationToken() - if sender is not None: - sender = self._get_agent(sender) - logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {message.__dict__}") # event_logger.info( @@ -144,11 +168,27 @@ def publish_message( # ) # ) + if sender is None and namespace is None: + raise ValueError("Namespace must be provided if sender is not provided.") + + sender_namespace = sender.namespace if sender is not None else None + explicit_namespace = namespace + if explicit_namespace is not None and sender_namespace is not None and explicit_namespace != sender_namespace: + raise ValueError( + f"Explicit namespace {explicit_namespace} does not match sender namespace {sender_namespace}" + ) + + assert explicit_namespace is not None or sender_namespace is not None + namespace = cast(str, explicit_namespace or sender_namespace) + if namespace not in self._known_namespaces: + self._prepare_namespace(namespace) + self._message_queue.append( PublishMessageEnvelope( message=message, cancellation_token=cancellation_token, sender=sender, + namespace=namespace, ) ) @@ -158,22 +198,25 @@ def publish_message( def save_state(self) -> Mapping[str, Any]: state: Dict[str, Dict[str, Any]] = {} - for agent in self._agents: - state[agent.metadata["name"]] = dict(agent.save_state()) + for agent_id in self._instantiated_agents: + state[str(agent_id)] = dict(self._get_agent(agent_id).save_state()) return state def load_state(self, state: Mapping[str, Any]) -> None: - for agent in self._agents: - agent.load_state(state[agent.metadata["name"]]) + for agent_id_str in state: + agent_id = AgentId.from_str(agent_id_str) + if agent_id.name in self._known_agent_names: + self._get_agent(agent_id).load_state(state[str(agent_id)]) async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: recipient = message_envelope.recipient - assert recipient in self._agents + # todo: check if recipient is in the known namespaces + # assert recipient in self._agents try: - sender_name = message_envelope.sender.metadata["name"] if message_envelope.sender is not None else "Unknown" + sender_name = message_envelope.sender.name if message_envelope.sender is not None else "Unknown" logger.info( - f"Calling message handler for {recipient.metadata['name']} with message type {type(message_envelope.message).__name__} sent by {sender_name}" + f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}" ) # event_logger.info( # MessageEvent( @@ -184,7 +227,8 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: # delivery_stage=DeliveryStage.DELIVER, # ) # ) - response = await recipient.on_message( + recipient_agent = self._get_agent(recipient) + response = await recipient_agent.on_message( message_envelope.message, cancellation_token=message_envelope.cancellation_token, ) @@ -200,19 +244,19 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: recipient=message_envelope.sender, ) ) + self._outstanding_tasks.decrement() async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: responses: List[Awaitable[Any]] = [] - for agent in self._per_type_subscribers.get(type(message_envelope.message), []): # type: ignore - if ( - message_envelope.sender is not None - and agent.metadata["name"] == message_envelope.sender.metadata["name"] - ): + target_namespace = message_envelope.namespace + for agent_id in self._per_type_subscribers[(target_namespace, type(message_envelope.message))]: + if message_envelope.sender is not None and agent_id.name == message_envelope.sender.name: continue - sender_name = message_envelope.sender.metadata["name"] if message_envelope.sender is not None else "Unknown" + sender_agent = self._get_agent(message_envelope.sender) if message_envelope.sender is not None else None + sender_name = sender_agent.metadata["name"] if sender_agent is not None else "Unknown" logger.info( - f"Calling message handler for {agent.metadata['name']} with message type {type(message_envelope.message).__name__} published by {sender_name}" + f"Calling message handler for {agent_id.name} with message type {type(message_envelope.message).__name__} published by {sender_name}" ) # event_logger.info( # MessageEvent( @@ -224,6 +268,7 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No # ) # ) + agent = self._get_agent(agent_id) future = agent.on_message( message_envelope.message, cancellation_token=message_envelope.cancellation_token, @@ -236,19 +281,17 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No logger.error("Error processing publish message", exc_info=True) return + self._outstanding_tasks.decrement() # TODO if responses are given for a publish async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: - recipient_name = ( - message_envelope.recipient.metadata["name"] if message_envelope.recipient is not None else "Unknown" - ) content = ( message_envelope.message.__dict__ if hasattr(message_envelope.message, "__dict__") else message_envelope.message ) logger.info( - f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {recipient_name} from {message_envelope.sender.metadata['name']}: {content}" + f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.name}: {content}" ) # event_logger.info( # MessageEvent( @@ -259,6 +302,7 @@ async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> # delivery_stage=DeliveryStage.DELIVER, # ) # ) + self._outstanding_tasks.decrement() message_envelope.future.set_result(message_envelope.message) async def process_next(self) -> None: @@ -282,7 +326,7 @@ async def process_next(self) -> None: return message_envelope.message = temp_message - + self._outstanding_tasks.increment() asyncio.create_task(self._process_send(message_envelope)) case PublishMessageEnvelope( message=message, @@ -300,7 +344,7 @@ async def process_next(self) -> None: return message_envelope.message = temp_message - + self._outstanding_tasks.increment() asyncio.create_task(self._process_publish(message_envelope)) case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): if self._before_send is not None: @@ -315,27 +359,94 @@ async def process_next(self) -> None: return message_envelope.message = temp_message - + self._outstanding_tasks.increment() asyncio.create_task(self._process_response(message_envelope)) # Yield control to the message loop to allow other tasks to run await asyncio.sleep(0) - def agent_metadata(self, agent: Agent | AgentId) -> AgentMetadata: + def agent_metadata(self, agent: AgentId) -> AgentMetadata: return self._get_agent(agent).metadata - def agent_save_state(self, agent: Agent | AgentId) -> Mapping[str, Any]: + def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: return self._get_agent(agent).save_state() - def agent_load_state(self, agent: Agent | AgentId, state: Mapping[str, Any]) -> None: + def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: self._get_agent(agent).load_state(state) - def _get_agent(self, agent_or_id: Agent | AgentId) -> Agent: - if isinstance(agent_or_id, Agent): - return agent_or_id - - for agent in self._agents: - if agent.metadata["name"] == agent_or_id.name: - return agent - - raise ValueError(f"Agent with name {agent_or_id} not found") + def register( + self, + name: str, + agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + *, + valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces, + ) -> None: + if name in self._agent_factories: + raise ValueError(f"Agent with name {name} already exists.") + self._agent_factories[name] = agent_factory + if valid_namespaces is not AllNamespaces: + self._valid_namespaces[name] = cast(Sequence[str], valid_namespaces) + else: + self._valid_namespaces[name] = [] + + def _invoke_agent_factory( + self, agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], agent_id: AgentId + ) -> T: + if len(inspect.signature(agent_factory).parameters) == 0: + factory_one = cast(Callable[[], T], agent_factory) + agent = factory_one() + elif len(inspect.signature(agent_factory).parameters) == 2: + factory_two = cast(Callable[[AgentRuntime, AgentId], T], agent_factory) + agent = factory_two(self, agent_id) + else: + raise ValueError("Agent factory must take 0 or 2 arguments.") + + # TODO: should this be part of the base agent interface? + if isinstance(agent, BaseAgent): + agent.bind_id(agent_id) + agent.bind_runtime(self) + + return agent + + def _type_valid_for_namespace(self, agent_id: AgentId) -> bool: + if agent_id.name not in self._agent_factories: + raise KeyError(f"Agent with name {agent_id.name} not found.") + + valid_namespaces = self._valid_namespaces[agent_id.name] + if len(valid_namespaces) == 0: + return True + + return agent_id.namespace in valid_namespaces + + def _get_agent(self, agent_id: AgentId) -> Agent: + if agent_id in self._instantiated_agents: + return self._instantiated_agents[agent_id] + + if not self._type_valid_for_namespace(agent_id): + raise ValueError(f"Agent with name {agent_id.name} not valid for namespace {agent_id.namespace}.") + + self._known_namespaces.add(agent_id.namespace) + if agent_id.name not in self._agent_factories: + raise ValueError(f"Agent with name {agent_id.name} not found.") + + agent_factory = self._agent_factories[agent_id.name] + + agent = self._invoke_agent_factory(agent_factory, agent_id) + for message_type in agent.metadata["subscriptions"]: + self._per_type_subscribers[(agent_id.namespace, message_type)].add(agent_id) + self._instantiated_agents[agent_id] = agent + return agent + + def get(self, name: str, *, namespace: str = "default") -> AgentId: + return self._get_agent(AgentId(name=name, namespace=namespace)).id + + def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: + id = self.get(name, namespace=namespace) + return AgentProxy(id, self) + + # Hydrate the agent instances in a namespace. The primary reason for this is + # to ensure message type subscriptions are set up. + def _prepare_namespace(self, namespace: str) -> None: + for name in self._known_agent_names: + if self._type_valid_for_namespace(AgentId(name=name, namespace=namespace)): + self._get_agent(AgentId(name=name, namespace=namespace)) diff --git a/src/agnext/chat/agents/chat_completion_agent.py b/src/agnext/chat/agents/chat_completion_agent.py index f522a804b46..1600fa688fc 100644 --- a/src/agnext/chat/agents/chat_completion_agent.py +++ b/src/agnext/chat/agents/chat_completion_agent.py @@ -14,7 +14,7 @@ SystemMessage, ) from ...components.tools import Tool -from ...core import Agent, AgentRuntime, CancellationToken +from ...core import AgentId, CancellationToken from ..memory import ChatMemory from ..types import ( FunctionCallMessage, @@ -59,16 +59,14 @@ class ChatCompletionAgent(TypeRoutedAgent): def __init__( self, - name: str, description: str, - runtime: AgentRuntime, system_messages: List[SystemMessage], memory: ChatMemory, model_client: ChatCompletionClient, tools: Sequence[Tool] = [], - tool_approver: Agent | None = None, + tool_approver: AgentId | None = None, ) -> None: - super().__init__(name, description, runtime) + super().__init__(description) self._description = description self._system_messages = system_messages self._client = model_client @@ -240,7 +238,7 @@ async def _execute_function( ) approval_response = await self._send_message( message=approval_request, - recipient=self._tool_approver.id, + recipient=self._tool_approver, cancellation_token=cancellation_token, ) if not isinstance(approval_response, ToolApprovalResponse): diff --git a/src/agnext/chat/agents/image_generation_agent.py b/src/agnext/chat/agents/image_generation_agent.py index 193f6ef6b64..cf1ffb7bc26 100644 --- a/src/agnext/chat/agents/image_generation_agent.py +++ b/src/agnext/chat/agents/image_generation_agent.py @@ -7,7 +7,7 @@ TypeRoutedAgent, message_handler, ) -from ...core import AgentRuntime, CancellationToken +from ...core import CancellationToken from ..memory import ChatMemory from ..types import ( MultiModalMessage, @@ -20,14 +20,12 @@ class ImageGenerationAgent(TypeRoutedAgent): def __init__( self, - name: str, description: str, - runtime: AgentRuntime, memory: ChatMemory, client: openai.AsyncClient, model: Literal["dall-e-2", "dall-e-3"] = "dall-e-2", ): - super().__init__(name, description, runtime) + super().__init__(description) self._client = client self._model = model self._memory = memory diff --git a/src/agnext/chat/agents/oai_assistant.py b/src/agnext/chat/agents/oai_assistant.py index 6b405ef2929..070657e2315 100644 --- a/src/agnext/chat/agents/oai_assistant.py +++ b/src/agnext/chat/agents/oai_assistant.py @@ -5,7 +5,7 @@ from openai.types.beta import AssistantResponseFormatParam from ...components import TypeRoutedAgent, message_handler -from ...core import AgentRuntime, CancellationToken +from ...core import CancellationToken from ..types import PublishNow, Reset, RespondNow, ResponseFormat, TextMessage @@ -28,15 +28,13 @@ class OpenAIAssistantAgent(TypeRoutedAgent): def __init__( self, - name: str, description: str, - runtime: AgentRuntime, client: openai.AsyncClient, assistant_id: str, thread_id: str, assistant_event_handler_factory: Callable[[], AsyncAssistantEventHandler] | None = None, ) -> None: - super().__init__(name, description, runtime) + super().__init__(description) self._client = client self._assistant_id = assistant_id self._thread_id = thread_id diff --git a/src/agnext/chat/agents/user_proxy.py b/src/agnext/chat/agents/user_proxy.py index e650d66482b..9aeb213408a 100644 --- a/src/agnext/chat/agents/user_proxy.py +++ b/src/agnext/chat/agents/user_proxy.py @@ -1,7 +1,7 @@ import asyncio from ...components import TypeRoutedAgent, message_handler -from ...core import AgentRuntime, CancellationToken +from ...core import CancellationToken from ..types import PublishNow, TextMessage @@ -16,8 +16,8 @@ class UserProxyAgent(TypeRoutedAgent): user_input_prompt (str): The console prompt to show to the user when asking for input. """ - def __init__(self, name: str, description: str, runtime: AgentRuntime, user_input_prompt: str) -> None: - super().__init__(name, description, runtime) + def __init__(self, description: str, user_input_prompt: str) -> None: + super().__init__(description) self._user_input_prompt = user_input_prompt @message_handler() diff --git a/src/agnext/chat/patterns/group_chat_manager.py b/src/agnext/chat/patterns/group_chat_manager.py index 80cae0685dd..50886de7d26 100644 --- a/src/agnext/chat/patterns/group_chat_manager.py +++ b/src/agnext/chat/patterns/group_chat_manager.py @@ -41,7 +41,6 @@ class GroupChatManager(TypeRoutedAgent): def __init__( self, - name: str, description: str, runtime: AgentRuntime, participants: List[AgentId], @@ -51,7 +50,7 @@ def __init__( transitions: Mapping[AgentId, List[AgentId]] = {}, on_message_received: Callable[[TextMessage | MultiModalMessage], None] | None = None, ): - super().__init__(name, description, runtime) + super().__init__(description) self._memory = memory self._client = model_client self._participants = participants diff --git a/src/agnext/chat/patterns/orchestrator_chat.py b/src/agnext/chat/patterns/orchestrator_chat.py index 3ecc4d36932..7bb5c664d81 100644 --- a/src/agnext/chat/patterns/orchestrator_chat.py +++ b/src/agnext/chat/patterns/orchestrator_chat.py @@ -2,7 +2,7 @@ from typing import Any, Sequence, Tuple from ...components import TypeRoutedAgent, message_handler -from ...core import AgentId, AgentProxy, AgentRuntime, CancellationToken +from ...core import AgentId, AgentRuntime, CancellationToken from ..types import Reset, RespondNow, ResponseFormat, TextMessage __all__ = ["OrchestratorChat"] @@ -11,7 +11,6 @@ class OrchestratorChat(TypeRoutedAgent): def __init__( self, - name: str, description: str, runtime: AgentRuntime, orchestrator: AgentId, @@ -21,21 +20,17 @@ def __init__( max_stalled_turns_before_retry: int = 2, max_retry_attempts: int = 1, ) -> None: - super().__init__(name, description, runtime) - self._orchestrator = AgentProxy(orchestrator, runtime) - self._planner = AgentProxy(planner, runtime) - self._specialists = [AgentProxy(x, runtime) for x in specialists] + super().__init__(description) + self._orchestrator = orchestrator + self._planner = planner + self._specialists = specialists self._max_turns = max_turns self._max_stalled_turns_before_retry = max_stalled_turns_before_retry self._max_retry_attempts_before_educated_guess = max_retry_attempts @property - def children(self) -> Sequence[str]: - return ( - [agent.metadata["name"] for agent in self._specialists] - + [self._orchestrator.metadata["name"]] - + [self._planner.metadata["name"]] - ) + def children(self) -> Sequence[AgentId]: + return list(self._specialists) + [self._orchestrator, self._planner] @message_handler() async def on_text_message( @@ -55,7 +50,7 @@ async def on_text_message( while total_turns < self._max_turns: # Reset all agents. for agent in [*self._specialists, self._orchestrator]: - await self._send_message(Reset(), agent.id) + await self._send_message(Reset(), agent) # Create the task specs. task_specs = f""" @@ -77,7 +72,7 @@ async def on_text_message( # Send the task specs to the orchestrator and specialists. for agent in [*self._specialists, self._orchestrator]: - await self._send_message(TextMessage(content=task_specs, source=self.metadata["name"]), agent.id) + await self._send_message(TextMessage(content=task_specs, source=self.metadata["name"]), agent) # Inner loop. stalled_turns = 0 @@ -133,19 +128,17 @@ async def on_text_message( for agent in [*self._specialists, self._orchestrator]: _ = await self._send_message( TextMessage(content=subtask, source=self.metadata["name"]), - agent.id, + agent, ) # Find the speaker. try: - speaker = next( - agent for agent in self._specialists if agent.metadata["name"] == data["next_speaker"]["answer"] - ) + speaker = next(agent for agent in self._specialists if agent.name == data["next_speaker"]["answer"]) except StopIteration as e: raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e # Ask speaker to speak. - speaker_response = await self._send_message(RespondNow(), speaker.id) + speaker_response = await self._send_message(RespondNow(), speaker) assert speaker_response is not None # Update all other agents with the speaker's response. @@ -155,7 +148,7 @@ async def on_text_message( content=speaker_response.content, source=speaker_response.source, ), - agent.id, + agent, ) # Increment the total turns. @@ -168,11 +161,13 @@ async def on_text_message( async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]: # Reset planner. - await self._send_message(Reset(), self._planner.id) + await self._send_message(Reset(), self._planner) # A reusable description of the team. - team = "\n".join([agent.metadata["name"] + ": " + agent.metadata["description"] for agent in self._specialists]) - names = ", ".join([agent.metadata["name"] for agent in self._specialists]) + team = "\n".join( + [agent.name + ": " + self.runtime.agent_metadata(agent)["description"] for agent in self._specialists] + ) + names = ", ".join([agent.name for agent in self._specialists]) # A place to store relevant facts. facts = "" @@ -203,8 +198,8 @@ async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, st """.strip() # Ask the planner to obtain prior knowledge about facts. - await self._send_message(TextMessage(content=closed_book_prompt, source=sender), self._planner.id) - facts_response = await self._send_message(RespondNow(), self._planner.id) + await self._send_message(TextMessage(content=closed_book_prompt, source=sender), self._planner) + facts_response = await self._send_message(RespondNow(), self._planner) facts = str(facts_response.content) @@ -216,8 +211,8 @@ async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, st Based on the team composition, and known and unknown facts, please devise a short bullet-point plan for addressing the original request. Remember, there is no requirement to involve all team members -- a team member's particular expertise may not be needed for this task.""".strip() # Send second messag eto the planner. - await self._send_message(TextMessage(content=plan_prompt, source=sender), self._planner.id) - plan_response = await self._send_message(RespondNow(), self._planner.id) + await self._send_message(TextMessage(content=plan_prompt, source=sender), self._planner) + plan_response = await self._send_message(RespondNow(), self._planner) plan = str(plan_response.content) return team, names, facts, plan @@ -269,11 +264,11 @@ async def _reflect_on_task( request = step_prompt while True: # Send a message to the orchestrator. - await self._send_message(TextMessage(content=request, source=sender), self._orchestrator.id) + await self._send_message(TextMessage(content=request, source=sender), self._orchestrator) # Request a response. step_response = await self._send_message( RespondNow(response_format=ResponseFormat.json_object), - self._orchestrator.id, + self._orchestrator, ) # TODO: use typed dictionary. try: @@ -332,9 +327,9 @@ async def _rewrite_facts(self, facts: str, sender: str) -> str: {facts} """.strip() # Send a message to the orchestrator. - await self._send_message(TextMessage(content=new_facts_prompt, source=sender), self._orchestrator.id) + await self._send_message(TextMessage(content=new_facts_prompt, source=sender), self._orchestrator) # Request a response. - new_facts_response = await self._send_message(RespondNow(), self._orchestrator.id) + new_facts_response = await self._send_message(RespondNow(), self._orchestrator) return str(new_facts_response.content) async def _educated_guess(self, facts: str, sender: str) -> Any: @@ -359,12 +354,12 @@ async def _educated_guess(self, facts: str, sender: str) -> Any: # Send a message to the orchestrator. await self._send_message( TextMessage(content=request, source=sender), - self._orchestrator.id, + self._orchestrator, ) # Request a response. response = await self._send_message( RespondNow(response_format=ResponseFormat.json_object), - self._orchestrator.id, + self._orchestrator, ) try: result = json.loads(str(response.content)) @@ -391,7 +386,7 @@ async def _rewrite_plan(self, team: str, sender: str) -> str: {team} """.strip() # Send a message to the orchestrator. - await self._send_message(TextMessage(content=new_plan_prompt, source=sender), self._orchestrator.id) + await self._send_message(TextMessage(content=new_plan_prompt, source=sender), self._orchestrator) # Request a response. - new_plan_response = await self._send_message(RespondNow(), self._orchestrator.id) + new_plan_response = await self._send_message(RespondNow(), self._orchestrator) return str(new_plan_response.content) diff --git a/src/agnext/components/_type_routed_agent.py b/src/agnext/components/_type_routed_agent.py index 72e6252ad59..b5c48fa8968 100644 --- a/src/agnext/components/_type_routed_agent.py +++ b/src/agnext/components/_type_routed_agent.py @@ -22,7 +22,7 @@ runtime_checkable, ) -from ..core import AgentRuntime, BaseAgent, CancellationToken +from ..core import BaseAgent, CancellationToken from ..core.exceptions import CantHandleException logger = logging.getLogger("agnext") @@ -162,7 +162,7 @@ async def wrapper(self: Any, message: ReceivesT, cancellation_token: Cancellatio class TypeRoutedAgent(BaseAgent): - def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None: + def __init__(self, description: str) -> None: # Self is already bound to the handlers self._handlers: Dict[ Type[Any], @@ -177,7 +177,7 @@ def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None: for target_type in message_handler.target_types: self._handlers[target_type] = message_handler subscriptions = list(self._handlers.keys()) - super().__init__(name, description, subscriptions, runtime) + super().__init__(description, subscriptions) async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: key_type: Type[Any] = type(message) # type: ignore diff --git a/src/agnext/core/__init__.py b/src/agnext/core/__init__.py index 37cd231a549..76f40438745 100644 --- a/src/agnext/core/__init__.py +++ b/src/agnext/core/__init__.py @@ -7,7 +7,7 @@ from ._agent_metadata import AgentMetadata from ._agent_props import AgentChildren from ._agent_proxy import AgentProxy -from ._agent_runtime import AgentRuntime +from ._agent_runtime import AgentRuntime, AllNamespaces from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken @@ -17,6 +17,7 @@ "AgentProxy", "AgentMetadata", "AgentRuntime", + "AllNamespaces", "BaseAgent", "CancellationToken", "AgentChildren", diff --git a/src/agnext/core/_agent_metadata.py b/src/agnext/core/_agent_metadata.py index 3d9b95e0db4..d3cb7259dd6 100644 --- a/src/agnext/core/_agent_metadata.py +++ b/src/agnext/core/_agent_metadata.py @@ -3,5 +3,6 @@ class AgentMetadata(TypedDict): name: str + namespace: str description: str subscriptions: Sequence[type] diff --git a/src/agnext/core/_agent_props.py b/src/agnext/core/_agent_props.py index 8e9ddbe4e6f..f2cb7f17c75 100644 --- a/src/agnext/core/_agent_props.py +++ b/src/agnext/core/_agent_props.py @@ -1,9 +1,11 @@ from typing import Protocol, Sequence, runtime_checkable +from ._agent_id import AgentId + @runtime_checkable class AgentChildren(Protocol): @property - def children(self) -> Sequence[str]: - """Names of the children of the agent.""" + def children(self) -> Sequence[AgentId]: + """Ids of the children of the agent.""" ... diff --git a/src/agnext/core/_agent_proxy.py b/src/agnext/core/_agent_proxy.py index aee4143368b..f53890bfc45 100644 --- a/src/agnext/core/_agent_proxy.py +++ b/src/agnext/core/_agent_proxy.py @@ -1,22 +1,25 @@ +from __future__ import annotations + from asyncio import Future -from typing import Any, Mapping +from typing import TYPE_CHECKING, Any, Mapping -from ._agent import Agent from ._agent_id import AgentId from ._agent_metadata import AgentMetadata -from ._agent_runtime import AgentRuntime from ._cancellation_token import CancellationToken +if TYPE_CHECKING: + from ._agent_runtime import AgentRuntime + class AgentProxy: - def __init__(self, agent: Agent | AgentId, runtime: AgentRuntime): + def __init__(self, agent: AgentId, runtime: AgentRuntime): self._agent = agent self._runtime = runtime @property def id(self) -> AgentId: """Target agent for this proxy""" - raise NotImplementedError + return self._agent @property def metadata(self) -> AgentMetadata: @@ -27,7 +30,7 @@ def send_message( self, message: Any, *, - sender: Agent, + sender: AgentId, cancellation_token: CancellationToken | None = None, ) -> Future[Any]: return self._runtime.send_message( diff --git a/src/agnext/core/_agent_runtime.py b/src/agnext/core/_agent_runtime.py index acba09c7df8..c1b60bc862f 100644 --- a/src/agnext/core/_agent_runtime.py +++ b/src/agnext/core/_agent_runtime.py @@ -1,33 +1,32 @@ +from __future__ import annotations + from asyncio import Future -from typing import Any, Mapping, Protocol +from typing import Any, Callable, Mapping, Protocol, Sequence, Type, TypeVar, overload, runtime_checkable from ._agent import Agent from ._agent_id import AgentId from ._agent_metadata import AgentMetadata +from ._agent_proxy import AgentProxy from ._cancellation_token import CancellationToken # Undeliverable - error +T = TypeVar("T", bound=Agent) -class AgentRuntime(Protocol): - def add_agent(self, agent: Agent) -> None: - """Add an agent to the runtime. - Args: - agent (Agent): Agent to add to the runtime. +class AllNamespaces: + pass - Note: - The name of the agent should be unique within the runtime. - """ - ... +@runtime_checkable +class AgentRuntime(Protocol): # Returns the response of the message def send_message( self, message: Any, - recipient: Agent | AgentId, + recipient: AgentId, *, - sender: Agent | AgentId | None = None, + sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> Future[Any]: ... @@ -36,16 +35,104 @@ def publish_message( self, message: Any, *, - sender: Agent | AgentId | None = None, + namespace: str | None = None, + sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> Future[None]: ... + @overload + def register( + self, name: str, agent_factory: Callable[[], T], *, valid_namespaces: Sequence[str] | Type[AllNamespaces] = ... + ) -> None: ... + + @overload + def register( + self, + name: str, + agent_factory: Callable[[AgentRuntime, AgentId], T], + *, + valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., + ) -> None: ... + + def register( + self, + name: str, + agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + *, + valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces, + ) -> None: ... + + def get(self, name: str, *, namespace: str = "default") -> AgentId: ... + def get_proxy(self, name: str, *, namespace: str = "default") -> AgentProxy: ... + + @overload + def register_and_get( + self, + name: str, + agent_factory: Callable[[], T], + *, + namespace: str = "default", + valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., + ) -> AgentId: ... + + @overload + def register_and_get( + self, + name: str, + agent_factory: Callable[[AgentRuntime, AgentId], T], + *, + namespace: str = "default", + valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., + ) -> AgentId: ... + + def register_and_get( + self, + name: str, + agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + *, + namespace: str = "default", + valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces, + ) -> AgentId: + self.register(name, agent_factory) + return self.get(name, namespace=namespace) + + @overload + def register_and_get_proxy( + self, + name: str, + agent_factory: Callable[[], T], + *, + namespace: str = "default", + valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., + ) -> AgentProxy: ... + + @overload + def register_and_get_proxy( + self, + name: str, + agent_factory: Callable[[AgentRuntime, AgentId], T], + *, + namespace: str = "default", + valid_namespaces: Sequence[str] | Type[AllNamespaces] = ..., + ) -> AgentProxy: ... + + def register_and_get_proxy( + self, + name: str, + agent_factory: Callable[[], T] | Callable[[AgentRuntime, AgentId], T], + *, + namespace: str = "default", + valid_namespaces: Sequence[str] | Type[AllNamespaces] = AllNamespaces, + ) -> AgentProxy: + self.register(name, agent_factory) + return self.get_proxy(name, namespace=namespace) + def save_state(self) -> Mapping[str, Any]: ... def load_state(self, state: Mapping[str, Any]) -> None: ... - def agent_metadata(self, agent: Agent | AgentId) -> AgentMetadata: ... + def agent_metadata(self, agent: AgentId) -> AgentMetadata: ... - def agent_save_state(self, agent: Agent | AgentId) -> Mapping[str, Any]: ... + def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]: ... - def agent_load_state(self, agent: Agent | AgentId, state: Mapping[str, Any]) -> None: ... + def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None: ... diff --git a/src/agnext/core/_base_agent.py b/src/agnext/core/_base_agent.py index 313d37dbb00..7bbb63a4e65 100644 --- a/src/agnext/core/_base_agent.py +++ b/src/agnext/core/_base_agent.py @@ -11,24 +11,50 @@ class BaseAgent(ABC, Agent): - def __init__(self, name: str, description: str, subscriptions: Sequence[type], router: AgentRuntime) -> None: - self._name = name - self._description = description - self._runtime = router - self._subscriptions = subscriptions - router.add_agent(self) - @property def metadata(self) -> AgentMetadata: + assert self._id is not None return AgentMetadata( - name=self._name, + namespace=self._id.namespace, + name=self._id.name, description=self._description, subscriptions=self._subscriptions, ) + def __init__(self, description: str, subscriptions: Sequence[type]) -> None: + self._runtime: AgentRuntime | None = None + self._id: AgentId | None = None + self._description = description + self._subscriptions = subscriptions + + def bind_runtime(self, runtime: AgentRuntime) -> None: + if self._runtime is not None: + raise RuntimeError("Agent has already been bound to a runtime.") + + self._runtime = runtime + + def bind_id(self, agent_id: AgentId) -> None: + if self._id is not None: + raise RuntimeError("Agent has already been bound to an id.") + self._id = agent_id + + @property + def name(self) -> str: + return self.id.name + @property def id(self) -> AgentId: - return AgentId(self._name, namespace="") + if self._id is None: + raise RuntimeError("Agent has not been bound to an id.") + + return self._id + + @property + def runtime(self) -> AgentRuntime: + if self._runtime is None: + raise RuntimeError("Agent has not been bound to a runtime.") + + return self._runtime @abstractmethod async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: ... @@ -41,12 +67,15 @@ def _send_message( *, cancellation_token: CancellationToken | None = None, ) -> Future[Any]: + if self._runtime is None: + raise RuntimeError("Agent has not been bound to a runtime.") + if cancellation_token is None: cancellation_token = CancellationToken() future = self._runtime.send_message( message, - sender=self, + sender=self.id, recipient=recipient, cancellation_token=cancellation_token, ) @@ -59,9 +88,13 @@ def _publish_message( *, cancellation_token: CancellationToken | None = None, ) -> Future[None]: + if self._runtime is None: + raise RuntimeError("Agent has not been bound to a runtime.") + if cancellation_token is None: cancellation_token = CancellationToken() - future = self._runtime.publish_message(message, sender=self, cancellation_token=cancellation_token) + + future = self._runtime.publish_message(message, sender=self.id, cancellation_token=cancellation_token) return future def save_state(self) -> Mapping[str, Any]: diff --git a/src/agnext/core/intervention.py b/src/agnext/core/intervention.py index 62606d63f47..afa8370854a 100644 --- a/src/agnext/core/intervention.py +++ b/src/agnext/core/intervention.py @@ -1,6 +1,6 @@ from typing import Any, Awaitable, Callable, Protocol, final -from agnext.core import Agent +from agnext.core import AgentId __all__ = [ "DropMessage", @@ -18,17 +18,19 @@ class DropMessage: ... class InterventionHandler(Protocol): - async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: ... - async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: ... - async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: ... + async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ... + async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ... + async def on_response( + self, message: Any, *, sender: AgentId, recipient: AgentId | None + ) -> Any | type[DropMessage]: ... class DefaultInterventionHandler(InterventionHandler): - async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: + async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: return message - async def on_publish(self, message: Any, *, sender: Agent | None) -> Any | type[DropMessage]: + async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: return message - async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: + async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]: return message diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 1c0a72d6631..0f43dc3f858 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -4,7 +4,7 @@ import pytest from agnext.application import SingleThreadedAgentRuntime from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import AgentId, AgentRuntime, CancellationToken +from agnext.core import AgentId, CancellationToken @dataclass @@ -15,14 +15,14 @@ class MessageType: # To do cancellation, only the token should be interacted with as a user # If you cancel a future, it may not work as you expect. -class LongRunningAgent(TypeRoutedAgent): # type: ignore - def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore - super().__init__(name, "A long running agent", router) +class LongRunningAgent(TypeRoutedAgent): + def __init__(self) -> None: + super().__init__("A long running agent") self.called = False self.cancelled = False @message_handler - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.called = True sleep = asyncio.ensure_future(asyncio.sleep(100)) cancellation_token.link_future(sleep) @@ -33,15 +33,15 @@ async def on_new_message(self, message: MessageType, cancellation_token: Cancell self.cancelled = True raise -class NestingLongRunningAgent(TypeRoutedAgent): # type: ignore - def __init__(self, name: str, router: AgentRuntime, nested_agent: AgentId) -> None: # type: ignore - super().__init__(name, "A nesting long running agent", router) +class NestingLongRunningAgent(TypeRoutedAgent): + def __init__(self, nested_agent: AgentId) -> None: + super().__init__("A nesting long running agent") self.called = False self.cancelled = False self._nested_agent = nested_agent @message_handler - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.called = True response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token) try: @@ -55,69 +55,74 @@ async def on_new_message(self, message: MessageType, cancellation_token: Cancell @pytest.mark.asyncio async def test_cancellation_with_token() -> None: - router = SingleThreadedAgentRuntime() + runtime = SingleThreadedAgentRuntime() - long_running = LongRunningAgent("name", router) + long_running = runtime.register_and_get("long_running", LongRunningAgent) token = CancellationToken() - response = router.send_message(MessageType(), recipient=long_running, cancellation_token=token) + response = runtime.send_message(MessageType(), recipient=long_running, cancellation_token=token) assert not response.done() - await router.process_next() + await runtime.process_next() token.cancel() with pytest.raises(asyncio.CancelledError): await response assert response.done() - assert long_running.called - assert long_running.cancelled + long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore + assert long_running_agent.called + assert long_running_agent.cancelled @pytest.mark.asyncio async def test_nested_cancellation_only_outer_called() -> None: - router = SingleThreadedAgentRuntime() + runtime = SingleThreadedAgentRuntime() - long_running = LongRunningAgent("name", router) - nested = NestingLongRunningAgent("nested", router, long_running.id) + long_running = runtime.register_and_get("long_running", LongRunningAgent) + nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) token = CancellationToken() - response = router.send_message(MessageType(), nested, cancellation_token=token) + response = runtime.send_message(MessageType(), nested, cancellation_token=token) assert not response.done() - await router.process_next() + await runtime.process_next() token.cancel() with pytest.raises(asyncio.CancelledError): await response assert response.done() - assert nested.called - assert nested.cancelled - assert long_running.called is False - assert long_running.cancelled is False + nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore + assert nested_agent.called + assert nested_agent.cancelled + long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore + assert long_running_agent.called is False + assert long_running_agent.cancelled is False @pytest.mark.asyncio async def test_nested_cancellation_inner_called() -> None: - router = SingleThreadedAgentRuntime() + runtime = SingleThreadedAgentRuntime() - long_running = LongRunningAgent("name", router) - nested = NestingLongRunningAgent("nested", router, long_running.id) + long_running = runtime.register_and_get("long_running", LongRunningAgent ) + nested = runtime.register_and_get("nested", lambda: NestingLongRunningAgent(long_running)) token = CancellationToken() - response = router.send_message(MessageType(), nested, cancellation_token=token) + response = runtime.send_message(MessageType(), nested, cancellation_token=token) assert not response.done() - await router.process_next() + await runtime.process_next() # allow the inner agent to process - await router.process_next() + await runtime.process_next() token.cancel() with pytest.raises(asyncio.CancelledError): await response assert response.done() - assert nested.called - assert nested.cancelled - assert long_running.called - assert long_running.cancelled + nested_agent: NestingLongRunningAgent = runtime._get_agent(nested) # type: ignore + assert nested_agent.called + assert nested_agent.cancelled + long_running_agent: LongRunningAgent = runtime._get_agent(long_running) # type: ignore + assert long_running_agent.called + assert long_running_agent.cancelled diff --git a/tests/test_intervention.py b/tests/test_intervention.py index 9c2edf878a6..ae3683e8b7c 100644 --- a/tests/test_intervention.py +++ b/tests/test_intervention.py @@ -1,63 +1,47 @@ -from dataclasses import dataclass - import pytest from agnext.application import SingleThreadedAgentRuntime -from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import Agent, AgentRuntime, CancellationToken +from agnext.core import AgentId from agnext.core.exceptions import MessageDroppedException from agnext.core.intervention import DefaultInterventionHandler, DropMessage +from test_utils import LoopbackAgent, MessageType -@dataclass -class MessageType: - ... - -class LoopbackAgent(TypeRoutedAgent): # type: ignore - def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore - super().__init__(name, "A loop back agent.", router) - self.num_calls = 0 - - - @message_handler() # type: ignore - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore - self.num_calls += 1 - return message - @pytest.mark.asyncio async def test_intervention_count_messages() -> None: - class DebugInterventionHandler(DefaultInterventionHandler): # type: ignore + class DebugInterventionHandler(DefaultInterventionHandler): def __init__(self) -> None: self.num_messages = 0 - async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType: # type: ignore + async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType: self.num_messages += 1 return message handler = DebugInterventionHandler() runtime = SingleThreadedAgentRuntime(before_send=handler) + loopback = runtime.register_and_get("name", LoopbackAgent) - long_running = LoopbackAgent("name", runtime) - response = runtime.send_message(MessageType(), recipient=long_running) + response = runtime.send_message(MessageType(), recipient=loopback) while not response.done(): await runtime.process_next() assert handler.num_messages == 1 - assert long_running.num_calls == 1 + loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore + assert loopback_agent.num_calls == 1 @pytest.mark.asyncio async def test_intervention_drop_send() -> None: - class DropSendInterventionHandler(DefaultInterventionHandler): # type: ignore - async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]: # type: ignore - return DropMessage # type: ignore + class DropSendInterventionHandler(DefaultInterventionHandler): + async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]: + return DropMessage handler = DropSendInterventionHandler() runtime = SingleThreadedAgentRuntime(before_send=handler) - long_running = LoopbackAgent("name", runtime) - response = runtime.send_message(MessageType(), recipient=long_running) + loopback = runtime.register_and_get("name", LoopbackAgent) + response = runtime.send_message(MessageType(), recipient=loopback) while not response.done(): await runtime.process_next() @@ -65,21 +49,22 @@ async def on_send(self, message: MessageType, *, sender: Agent | None, recipient with pytest.raises(MessageDroppedException): await response - assert long_running.num_calls == 0 + loopback_agent: LoopbackAgent = runtime._get_agent(loopback) # type: ignore + assert loopback_agent.num_calls == 0 @pytest.mark.asyncio async def test_intervention_drop_response() -> None: - class DropResponseInterventionHandler(DefaultInterventionHandler): # type: ignore - async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]: # type: ignore - return DropMessage # type: ignore + class DropResponseInterventionHandler(DefaultInterventionHandler): + async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]: + return DropMessage handler = DropResponseInterventionHandler() runtime = SingleThreadedAgentRuntime(before_send=handler) - long_running = LoopbackAgent("name", runtime) - response = runtime.send_message(MessageType(), recipient=long_running) + loopback = runtime.register_and_get("name", LoopbackAgent) + response = runtime.send_message(MessageType(), recipient=loopback) while not response.done(): await runtime.process_next() @@ -87,7 +72,6 @@ async def on_response(self, message: MessageType, *, sender: Agent, recipient: A with pytest.raises(MessageDroppedException): await response - assert long_running.num_calls == 1 @pytest.mark.asyncio async def test_intervention_raise_exception_on_send() -> None: @@ -96,13 +80,13 @@ class InterventionException(Exception): pass class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore - async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]: # type: ignore + async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType | type[DropMessage]: # type: ignore raise InterventionException handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(before_send=handler) - long_running = LoopbackAgent("name", runtime) + long_running = runtime.register_and_get("name", LoopbackAgent) response = runtime.send_message(MessageType(), recipient=long_running) while not response.done(): @@ -111,7 +95,8 @@ async def on_send(self, message: MessageType, *, sender: Agent | None, recipient with pytest.raises(InterventionException): await response - assert long_running.num_calls == 0 + long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore + assert long_running_agent.num_calls == 0 @pytest.mark.asyncio async def test_intervention_raise_exception_on_respond() -> None: @@ -120,13 +105,13 @@ class InterventionException(Exception): pass class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore - async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]: # type: ignore + async def on_response(self, message: MessageType, *, sender: AgentId, recipient: AgentId | None) -> MessageType | type[DropMessage]: # type: ignore raise InterventionException handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(before_send=handler) - long_running = LoopbackAgent("name", runtime) + long_running = runtime.register_and_get("name", LoopbackAgent) response = runtime.send_message(MessageType(), recipient=long_running) while not response.done(): @@ -135,4 +120,5 @@ async def on_response(self, message: MessageType, *, sender: Agent, recipient: A with pytest.raises(InterventionException): await response - assert long_running.num_calls == 1 + long_running_agent: LoopbackAgent = runtime._get_agent(long_running) # type: ignore + assert long_running_agent.num_calls == 1 diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 319a59c99f8..ea5798c25e5 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -1,31 +1,75 @@ -from typing import Any, Sequence +from typing import Any import pytest from agnext.application import SingleThreadedAgentRuntime -from agnext.core import AgentRuntime, BaseAgent, CancellationToken +from agnext.core import BaseAgent, CancellationToken +from test_utils import LoopbackAgent, MessageType -class NoopAgent(BaseAgent): # type: ignore - def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore - super().__init__(name, "A no op agent", [], runtime) +class NoopAgent(BaseAgent): # type: ignore + def __init__(self) -> None: # type: ignore + super().__init__("A no op agent", []) - @property - def subscriptions(self) -> Sequence[type]: - return [] - - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore raise NotImplementedError - @pytest.mark.asyncio async def test_agent_names_must_be_unique() -> None: runtime = SingleThreadedAgentRuntime() - _agent1 = NoopAgent("name1", runtime) + _agent1 = runtime.register_and_get("name1", NoopAgent) with pytest.raises(ValueError): - _agent1_again = NoopAgent("name1", runtime) + _agent1 = runtime.register_and_get("name1", NoopAgent) + + _agent1 = runtime.register_and_get("name3", NoopAgent) + +@pytest.mark.asyncio +async def test_register_receives_publish() -> None: + runtime = SingleThreadedAgentRuntime() + + runtime.register("name", LoopbackAgent) + await runtime.publish_message(MessageType(), namespace="default") - _agent3 = NoopAgent("name3", runtime) + while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0: + await runtime.process_next() + # Agent in default namespace should have received the message + long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore + assert long_running_agent.num_calls == 1 + + # Agent in other namespace should not have received the message + other_long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name", namespace="other")) # type: ignore + assert other_long_running_agent.num_calls == 0 + + + +@pytest.mark.asyncio +async def test_try_instantiate_agent_invalid_namespace() -> None: + runtime = SingleThreadedAgentRuntime() + + runtime.register("name", LoopbackAgent, valid_namespaces=["default"]) + await runtime.publish_message(MessageType(), namespace="non_default") + + while len(runtime.unprocessed_messages) > 0 or runtime.outstanding_tasks > 0: + await runtime.process_next() + + # Agent in default namespace should have received the message + long_running_agent: LoopbackAgent = runtime._get_agent(runtime.get("name")) # type: ignore + assert long_running_agent.num_calls == 0 + + with pytest.raises(ValueError): + _agent = runtime.get("name", namespace="non_default") + +@pytest.mark.asyncio +async def test_send_crosses_namepace() -> None: + runtime = SingleThreadedAgentRuntime() + + runtime.register("name", LoopbackAgent) + + default_ns_agent = runtime.get("name") + non_default_ns_agent = runtime.get("name", namespace="non_default") + + with pytest.raises(ValueError): + await runtime.send_message(MessageType(), default_ns_agent, sender=non_default_ns_agent) diff --git a/tests/test_state.py b/tests/test_state.py index f72a7189390..08c63a38d47 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -2,12 +2,12 @@ import pytest from agnext.application import SingleThreadedAgentRuntime -from agnext.core import AgentRuntime, BaseAgent, CancellationToken +from agnext.core import BaseAgent, CancellationToken class StatefulAgent(BaseAgent): # type: ignore - def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore - super().__init__(name, "A stateful agent", [], runtime) + def __init__(self) -> None: # type: ignore + super().__init__("A stateful agent", []) self.state = 0 @property @@ -28,7 +28,8 @@ def load_state(self, state: Mapping[str, Any]) -> None: async def test_agent_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() - agent1 = StatefulAgent("name1", runtime) + agent1_id = runtime.register_and_get("name1", StatefulAgent) + agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore assert agent1.state == 0 agent1.state = 1 assert agent1.state == 1 @@ -45,7 +46,8 @@ async def test_agent_can_save_state() -> None: async def test_runtime_can_save_state() -> None: runtime = SingleThreadedAgentRuntime() - agent1 = StatefulAgent("name1", runtime) + agent1_id = runtime.register_and_get("name1", StatefulAgent) + agent1: StatefulAgent = runtime._get_agent(agent1_id) # type: ignore assert agent1.state == 0 agent1.state = 1 assert agent1.state == 1 @@ -53,7 +55,9 @@ async def test_runtime_can_save_state() -> None: runtime_state = runtime.save_state() runtime2 = SingleThreadedAgentRuntime() - agent2 = StatefulAgent("name1", runtime2) + agent2_id = runtime2.register_and_get("name1", StatefulAgent) + agent2: StatefulAgent = runtime2._get_agent(agent2_id) # type: ignore + runtime2.load_state(runtime_state) assert agent2.state == 1 diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 00000000000..12b6797740d --- /dev/null +++ b/tests/test_utils/__init__.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +from agnext.components import TypeRoutedAgent, message_handler +from agnext.core import CancellationToken + + +@dataclass +class MessageType: + ... + +class LoopbackAgent(TypeRoutedAgent): + def __init__(self) -> None: + super().__init__("A loop back agent.") + self.num_calls = 0 + + + @message_handler + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + self.num_calls += 1 + return message \ No newline at end of file