From f32f9eea02cd9537de166675787046b6beb7504a Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 15 May 2024 12:31:13 -0400 Subject: [PATCH] migrate prototype to initial impl --- .github/workflows/checks.yml | 18 ++-- examples/example.py | 137 ----------------------------- examples/futures.py | 52 +++++++++++ pyproject.toml | 6 +- src/agnext/core/__init__.py | 0 src/agnext/core/agent.py | 15 ++++ src/agnext/core/base_agent.py | 34 +++++++ src/agnext/core/exceptions.py | 6 ++ src/agnext/core/message.py | 6 ++ src/agnext/core/message_router.py | 20 +++++ src/agnext/prototype.py | 130 --------------------------- src/agnext/queue_message_router.py | 90 +++++++++++++++++++ src/agnext/type_routed_agent.py | 48 ++++++++++ 13 files changed, 283 insertions(+), 279 deletions(-) delete mode 100644 examples/example.py create mode 100644 examples/futures.py create mode 100644 src/agnext/core/__init__.py create mode 100644 src/agnext/core/agent.py create mode 100644 src/agnext/core/base_agent.py create mode 100644 src/agnext/core/exceptions.py create mode 100644 src/agnext/core/message.py create mode 100644 src/agnext/core/message_router.py delete mode 100644 src/agnext/prototype.py create mode 100644 src/agnext/queue_message_router.py create mode 100644 src/agnext/type_routed_agent.py diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index ce337074969..e92b6a0e795 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -51,12 +51,12 @@ jobs: - run: pip install ".[dev]" - uses: jakebailey/pyright-action@v2 - test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - run: pip install ".[dev]" - - run: pytest + # test: + # runs-on: ubuntu-latest + # steps: + # - uses: actions/checkout@v4 + # - uses: actions/setup-python@v5 + # with: + # python-version: '3.10' + # - run: pip install ".[dev]" + # - run: pytest diff --git a/examples/example.py b/examples/example.py deleted file mode 100644 index dc3625f5c7a..00000000000 --- a/examples/example.py +++ /dev/null @@ -1,137 +0,0 @@ -import asyncio -import random -from dataclasses import dataclass -from typing import Awaitable, Callable, List, Optional, Sequence, cast - -from agnext.prototype import Agent, Event, EventQueue, EventRouter, TypeRoutedAgent, event_handler - - -@dataclass -class InputEvent(Event): - message: str - sender: str - - -@dataclass -class NewEvent(Event): - message: str - sender: str - recipient: str - - -@dataclass -class ResponseEvent(Event): - message: Optional[str] - sender: str - - -GroupChatEvents = InputEvent | NewEvent | ResponseEvent - - -class GroupChatManager(TypeRoutedAgent[GroupChatEvents]): - def __init__( - self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]], agents: Sequence[Agent] - ) -> None: - super().__init__(name, emit_event) - self._agents = agents - self._current_speaker = 0 - self._events: List[GroupChatEvents] = [] - self._responses: List[ResponseEvent] = [] - - @event_handler(InputEvent) - async def on_input_event(self, event: InputEvent) -> None: - # New group chat - self._events.clear() - - recipient_agent = self._agents[self._current_speaker] - self._current_speaker = (self._current_speaker + 1) % len(self._agents) - - new_event = NewEvent(message=event.message, sender=self.name, recipient=recipient_agent.name) - self._events.append(event) - await self.emit_event(new_event) - - @event_handler(ResponseEvent) - async def on_group_chat_event(self, event: ResponseEvent) -> None: - self._responses.append(event) - - # TODO: Handle termination and replying to original sender - - # Received response from all - proceeed - if len(self._responses) == len(self._agents): - recipient_agent = self._agents[self._current_speaker] - self._current_speaker = (self._current_speaker + 1) % len(self._agents) - - responses_with_content = [x for x in self._responses if x.message is not None] - if len(responses_with_content) != 1: - raise ValueError("Can't handle anything other than 1 response right now.") - - new_event = NewEvent( - message=cast(str, responses_with_content[0].message), sender=self.name, recipient=recipient_agent.name - ) - self._events.append(new_event) - self._responses.clear() - await self.emit_event(new_event) - - async def on_unhandled_event(self, event: GroupChatEvents) -> None: - raise ValueError("Unknown") - - -class Critic(TypeRoutedAgent[GroupChatEvents]): - def __init__(self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]]) -> None: - super().__init__(name, emit_event) - - @event_handler(NewEvent) - async def on_new_event(self, event: NewEvent) -> None: - if event.recipient == self.name: - response = random.choice([" is a good idea", " is a bad idea"]) - await self.emit_event(ResponseEvent(event.message + response, sender=self.name)) - else: - await self.emit_event(ResponseEvent(None, sender=self.name)) - - async def on_unhandled_event(self, event: GroupChatEvents) -> None: - raise ValueError("Unknown") - - -class Suggester(TypeRoutedAgent[GroupChatEvents]): - def __init__(self, name: str, emit_event: Callable[[GroupChatEvents], Awaitable[None]]) -> None: - super().__init__(name, emit_event) - - @event_handler(NewEvent) - async def on_new_event(self, event: NewEvent) -> None: - if event.recipient == self.name: - response = random.choice( - ["Attach wheels to a laptop", "merge a banana and an apple", "Cheese but made with oats"] - ) - await self.emit_event(ResponseEvent(response, sender=self.name)) - else: - await self.emit_event(ResponseEvent(None, sender=self.name)) - - async def on_unhandled_event(self, event: GroupChatEvents) -> None: - raise ValueError("Unknown") - - -async def main(): - event_queue = EventQueue[GroupChatEvents]() - - critic = Critic("Critic", event_queue.into_callable()) - suggester = Suggester("Suggester", event_queue.into_callable()) - group_chat_manager = GroupChatManager("Manager", event_queue.into_callable(), [critic, suggester]) - processor = EventRouter(event_queue, [critic, suggester, group_chat_manager]) - - await event_queue.emit(InputEvent(message="Go", sender="external")) - - await processor.process_next() - await processor.process_next() - await processor.process_next() - await processor.process_next() - await processor.process_next() - await processor.process_next() - await processor.process_next() - await processor.process_next() - await processor.process_next() - await processor.process_next() - await processor.process_next() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/futures.py b/examples/futures.py new file mode 100644 index 00000000000..55d21e73531 --- /dev/null +++ b/examples/futures.py @@ -0,0 +1,52 @@ +import asyncio +from dataclasses import dataclass + +from agnext.core.agent import Agent +from agnext.core.message import Message +from agnext.core.message_router import MessageRouter +from agnext.queue_message_router import QueueMessageRouter +from agnext.type_routed_agent import TypeRoutedAgent, event_handler + + +@dataclass +class MessageType(Message): + message: str + sender: str + + +class Inner(TypeRoutedAgent[MessageType]): + def __init__(self, name: str, router: MessageRouter[MessageType]) -> None: + super().__init__(name, router) + + @event_handler(MessageType) + async def on_new_event(self, event: MessageType) -> MessageType: + return MessageType(message=f"Inner: {event.message}", sender=self.name) + + +class Outer(TypeRoutedAgent[MessageType]): + def __init__(self, name: str, router: MessageRouter[MessageType], inner: Agent[MessageType]) -> None: + super().__init__(name, router) + self._inner = inner + + @event_handler(MessageType) + async def on_new_event(self, event: MessageType) -> MessageType: + inner_response = self._send_message(event, self._inner) + inner_message = await inner_response + return MessageType(message=f"Outer: {inner_message.message}", sender=self.name) + + +async def main() -> None: + router = QueueMessageRouter[MessageType]() + + inner = Inner("inner", router) + outer = Outer("outer", router, inner) + response = router.send_message(MessageType(message="Hello", sender="external"), outer) + + while not response.done(): + await router.process_next() + + print(await response) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index d806d686782..cddb136e1fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,14 +25,14 @@ line-length = 120 fix = true exclude = ["build", "dist", "my_project/__init__.py", "my_project/main.py"] target-version = "py310" -include = ["src/**", "examples/**", "tests/**"] +include = ["src/**", "examples/**"] [tool.ruff.lint] select = ["E", "F", "W", "B", "Q", "I"] ignore = ["F401", "E501"] [tool.mypy] -files = ["src", "examples", "tests"] +files = ["src", "examples"] strict = true python_version = "3.10" @@ -51,7 +51,7 @@ disallow_untyped_decorators = true disallow_any_unimported = true [tool.pyright] -include = ["src", "examples", "tests"] +include = ["src", "examples"] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false diff --git a/src/agnext/core/__init__.py b/src/agnext/core/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/agnext/core/agent.py b/src/agnext/core/agent.py new file mode 100644 index 00000000000..165fe2c8ade --- /dev/null +++ b/src/agnext/core/agent.py @@ -0,0 +1,15 @@ +from typing import Protocol, Sequence, Type, TypeVar + +from .message import Message + +T = TypeVar("T", bound=Message) + + +class Agent(Protocol[T]): + @property + def name(self) -> str: ... + + @property + def subscriptions(self) -> Sequence[Type[T]]: ... + + async def on_event(self, event: T) -> T: ... diff --git a/src/agnext/core/base_agent.py b/src/agnext/core/base_agent.py new file mode 100644 index 00000000000..d23e5534ac1 --- /dev/null +++ b/src/agnext/core/base_agent.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from asyncio import Future +from typing import List, Sequence, Type, TypeVar + +from agnext.core.message_router import MessageRouter + +from .agent import Agent +from .message import Message + +T = TypeVar("T", bound=Message) + + +class BaseAgent(ABC, Agent[T]): + def __init__(self, name: str, router: MessageRouter[T]) -> None: + self._name = name + self._router = router + + @property + def name(self) -> str: + return self._name + + @property + @abstractmethod + def subscriptions(self) -> Sequence[Type[T]]: + return [] + + @abstractmethod + async def on_event(self, event: T) -> T: ... + + def _send_message(self, message: T, destination: Agent[T]) -> Future[T]: + return self._router.send_message(message, destination) + + def _broadcast_message(self, message: T) -> Future[List[T]]: + return self._router.broadcast_message(message) diff --git a/src/agnext/core/exceptions.py b/src/agnext/core/exceptions.py new file mode 100644 index 00000000000..8d305d54e05 --- /dev/null +++ b/src/agnext/core/exceptions.py @@ -0,0 +1,6 @@ +class CantHandleException(Exception): + """Raised when a handler can't handle the exception.""" + + +class UndeliverableException(Exception): + """Raised when a message can't be delivered.""" diff --git a/src/agnext/core/message.py b/src/agnext/core/message.py new file mode 100644 index 00000000000..8edf09146c3 --- /dev/null +++ b/src/agnext/core/message.py @@ -0,0 +1,6 @@ +from typing import Protocol + + +class Message(Protocol): + sender: str + # reply_to: Optional[str] diff --git a/src/agnext/core/message_router.py b/src/agnext/core/message_router.py new file mode 100644 index 00000000000..fc9cbb013d0 --- /dev/null +++ b/src/agnext/core/message_router.py @@ -0,0 +1,20 @@ +from asyncio import Future +from typing import List, Protocol, TypeVar + +from agnext.core.agent import Agent + +from .message import Message + +T = TypeVar("T", bound=Message) + +# Undeliverable - error + + +class MessageRouter(Protocol[T]): + def add_agent(self, agent: Agent[T]) -> None: ... + + # Returns the response of the message + def send_message(self, message: T, destination: Agent[T]) -> Future[T]: ... + + # Returns the response of all handling agents + def broadcast_message(self, message: T) -> Future[List[T]]: ... diff --git a/src/agnext/prototype.py b/src/agnext/prototype.py deleted file mode 100644 index 4e49eb06f61..00000000000 --- a/src/agnext/prototype.py +++ /dev/null @@ -1,130 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Awaitable, Callable, Dict, List, Protocol, Sequence, Type - -# Type based routing for event - -# -metadata -# Receipt -# Request response -# Type/Kind - - -# DELIVERY RECEIPTS - - -# on event -# on event with receipt - - -class Event(Protocol): - sender: str - # reply_to: Optional[str] - - -# T must encompass all subscribed types for a given agent - - -class Agent(Protocol): - @property - def name(self) -> str: ... - - -class EventBasedAgent[T: Event](Agent): - @property - def subscriptions(self) -> Sequence[Type[T]]: ... - - async def on_event(self, event: T) -> None: ... - - # async def _send_event(self, event: T) -> None: - # ... - - # async def _broadcast_message(self, event: T) -> None: - # ... - - -# NOTE: this works on concrete types and not inheritance -def event_handler[T: Event](target_type: Type[T]): - def decorator(func: Callable[..., Awaitable[None]]) -> Callable[..., Awaitable[None]]: - func._target_type = target_type # type: ignore - return func - - return decorator - - -class TypeRoutedAgent[T: Event](EventBasedAgent[T], ABC): - def __init__(self, name: str, emit_event: Callable[[T], Awaitable[None]]) -> None: - self._name = name - self._handlers: Dict[Type[Any], Callable[[T], Awaitable[None]]] = {} - self._emit_event = emit_event - - for attr in dir(self): - if callable(getattr(self, attr)): - handler = getattr(self, attr) - if hasattr(handler, "_target_type"): - # TODO do i need to partially apply self? - self._handlers[handler._target_type] = handler - - @property - def name(self) -> str: - return self._name - - @property - def subscriptions(self) -> Sequence[Type[T]]: - return list(self._handlers.keys()) - - async def emit_event(self, event: T) -> None: - await self._emit_event(event) - - async def on_event(self, event: T) -> None: - handler = self._handlers.get(type(event)) - if handler is not None: - await handler(event) - else: - await self.on_unhandled_event(event) - - @abstractmethod - async def on_unhandled_event(self, event: T) -> None: ... - - -class EventQueue[U]: - def __init__(self) -> None: - self._queue: List[U] = [] - - async def emit(self, event: U) -> None: - print(event) - self._queue.append(event) - - def pop_event(self) -> U: - return self._queue.pop(0) - - def empty(self) -> bool: - return len(self._queue) == 0 - - def into_callable(self) -> Callable[[U], Awaitable[None]]: - return self.emit - - -class EventRouter[T: Event]: - def __init__(self, event_queue: EventQueue[T], agents: Sequence[EventBasedAgent[T]]) -> None: - self._event_queue = event_queue - # Use default dict i just cant remember the syntax and im without internet - self._per_type_subscribers: Dict[Type[T], List[EventBasedAgent[T]]] = {} - for agent in agents: - subscriptions = agent.subscriptions - for subscription in subscriptions: - if subscription not in self._per_type_subscribers: - self._per_type_subscribers[subscription] = [] - - self._per_type_subscribers[subscription].append(agent) - - async def process_next(self) -> None: - if self._event_queue.empty(): - return - - event = self._event_queue.pop_event() - subscribers = self._per_type_subscribers.get(type(event)) - if subscribers is not None: - for subscriber in subscribers: - await subscriber.on_event(event) - else: - print(f"Event {event} has no recipient agent") diff --git a/src/agnext/queue_message_router.py b/src/agnext/queue_message_router.py new file mode 100644 index 00000000000..2980b3f36a1 --- /dev/null +++ b/src/agnext/queue_message_router.py @@ -0,0 +1,90 @@ +import asyncio +from asyncio import Future +from dataclasses import dataclass +from typing import Dict, Generic, List, Set, Type, TypeVar + +from agnext.core.agent import Agent + +from .core.message import Message +from .core.message_router import MessageRouter + +T = TypeVar("T", bound=Message) + + +@dataclass +class BroadcastMessage(Generic[T]): + message: T + future: Future[List[T]] + + +@dataclass +class SendMessage(Generic[T]): + message: T + destination: Agent[T] + future: Future[T] + + +@dataclass +class ResponseMessage(Generic[T]): ... + + +class QueueMessageRouter(MessageRouter[T]): + def __init__(self) -> None: + self._event_queue: List[BroadcastMessage[T] | SendMessage[T]] = [] + self._per_type_subscribers: Dict[Type[T], List[Agent[T]]] = {} + self._agents: Set[Agent[T]] = set() + + def add_agent(self, agent: Agent[T]) -> None: + for event_type in agent.subscriptions: + if event_type not in self._per_type_subscribers: + self._per_type_subscribers[event_type] = [] + self._per_type_subscribers[event_type].append(agent) + self._agents.add(agent) + + # Returns the response of the message + def send_message(self, message: T, destination: Agent[T]) -> Future[T]: + loop = asyncio.get_event_loop() + future: Future[T] = loop.create_future() + + self._event_queue.append(SendMessage(message, destination, future)) + + return future + + # Returns the response of all handling agents + def broadcast_message(self, message: T) -> Future[List[T]]: + future: Future[List[T]] = asyncio.get_event_loop().create_future() + self._event_queue.append(BroadcastMessage(message, future)) + return future + + async def _process_send(self, message: SendMessage[T]) -> None: + recipient = message.destination + if recipient not in self._agents: + message.future.set_exception(Exception("Recipient not found")) + return + + response = await recipient.on_event(message.message) + message.future.set_result(response) + + async def _process_broadcast(self, message: BroadcastMessage[T]) -> None: + responses: List[T] = [] + for agent in self._per_type_subscribers.get(type(message.message), []): + response = await agent.on_event(message.message) + responses.append(response) + message.future.set_result(responses) + + async def process_next(self) -> None: + if len(self._event_queue) == 0: + # Yield control to the event loop to allow other tasks to run + await asyncio.sleep(0) + return + + event = self._event_queue.pop(0) + + match event: + case SendMessage(message, destination, future): + asyncio.create_task(self._process_send(SendMessage(message, destination, future))) + case BroadcastMessage(message, future): + asyncio.create_task(self._process_broadcast(BroadcastMessage(message, future))) + + # Yield control to the event loop to allow other tasks to run + await asyncio.sleep(0) diff --git a/src/agnext/type_routed_agent.py b/src/agnext/type_routed_agent.py new file mode 100644 index 00000000000..5c8ef06821e --- /dev/null +++ b/src/agnext/type_routed_agent.py @@ -0,0 +1,48 @@ +from typing import Any, Awaitable, Callable, Dict, Sequence, Type, TypeVar + +from agnext.core.base_agent import BaseAgent +from agnext.core.exceptions import CantHandleException +from agnext.core.message_router import MessageRouter + +from .core.message import Message + +T = TypeVar("T", bound=Message) + + +# NOTE: this works on concrete types and not inheritance +def event_handler(target_type: Type[T]) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]: + def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: + func._target_type = target_type # type: ignore + return func + + return decorator + + +class TypeRoutedAgent(BaseAgent[T]): + def __init__(self, name: str, router: MessageRouter[T]) -> None: + super().__init__(name, router) + + self._handlers: Dict[Type[Any], Callable[[T], Awaitable[T]]] = {} + + router.add_agent(self) + + for attr in dir(self): + if callable(getattr(self, attr)): + handler = getattr(self, attr) + if hasattr(handler, "_target_type"): + # TODO do i need to partially apply self? + self._handlers[handler._target_type] = handler + + @property + def subscriptions(self) -> Sequence[Type[T]]: + return list(self._handlers.keys()) + + async def on_event(self, event: T) -> T: + handler = self._handlers.get(type(event)) + if handler is not None: + return await handler(event) + else: + return await self.on_unhandled_event(event) + + async def on_unhandled_event(self, event: T) -> T: + raise CantHandleException()