diff --git a/src/agnext/application_components/single_threaded_agent_runtime.py b/src/agnext/application_components/single_threaded_agent_runtime.py index bda10272c82b..749f431f37d4 100644 --- a/src/agnext/application_components/single_threaded_agent_runtime.py +++ b/src/agnext/application_components/single_threaded_agent_runtime.py @@ -1,9 +1,11 @@ import asyncio from asyncio import Future from dataclasses import dataclass -from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar +from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar, cast from agnext.core.cancellation_token import CancellationToken +from agnext.core.exceptions import MessageDroppedException +from agnext.core.intervention import DropMessage, InterventionHandler from ..core.agent import Agent from ..core.agent_runtime import AgentRuntime @@ -11,7 +13,7 @@ T = TypeVar("T") -@dataclass +@dataclass(kw_only=True) class BroadcastMessageEnvelope(Generic[T]): """A message envelope for broadcasting messages to all agents that can handle the message of the type T.""" @@ -19,37 +21,42 @@ class BroadcastMessageEnvelope(Generic[T]): message: T future: Future[List[T]] cancellation_token: CancellationToken + sender: Agent[T] | None -@dataclass +@dataclass(kw_only=True) class SendMessageEnvelope(Generic[T]): """A message envelope for sending a message to a specific agent that can handle the message of the type T.""" message: T - destination: Agent[T] + sender: Agent[T] | None + recipient: Agent[T] future: Future[T] cancellation_token: CancellationToken -@dataclass +@dataclass(kw_only=True) class ResponseMessageEnvelope(Generic[T]): """A message envelope for sending a response to a message.""" message: T future: Future[T] + sender: Agent[T] + recipient: Agent[T] | None -@dataclass +@dataclass(kw_only=True) class BroadcastResponseMessageEnvelope(Generic[T]): """A message envelope for sending a response to a message.""" message: List[T] future: Future[List[T]] + recipient: Agent[T] | None class SingleThreadedAgentRuntime(AgentRuntime[T]): - def __init__(self) -> None: + def __init__(self, *, before_send: InterventionHandler[T] | None = None) -> None: self._message_queue: List[ BroadcastMessageEnvelope[T] | SendMessageEnvelope[T] @@ -58,6 +65,7 @@ def __init__(self) -> None: ] = [] self._per_type_subscribers: Dict[Type[T], List[Agent[T]]] = {} self._agents: Set[Agent[T]] = set() + self._before_send = before_send def add_agent(self, agent: Agent[T]) -> None: for message_type in agent.subscriptions: @@ -68,7 +76,12 @@ def add_agent(self, agent: Agent[T]) -> None: # Returns the response of the message def send_message( - self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None + self, + message: T, + recipient: Agent[T], + *, + sender: Agent[T] | None = None, + cancellation_token: CancellationToken | None = None, ) -> Future[T]: if cancellation_token is None: cancellation_token = CancellationToken() @@ -76,21 +89,35 @@ def send_message( loop = asyncio.get_event_loop() future: Future[T] = loop.create_future() - self._message_queue.append(SendMessageEnvelope(message, destination, future, cancellation_token)) + self._message_queue.append( + SendMessageEnvelope( + message=message, + recipient=recipient, + future=future, + cancellation_token=cancellation_token, + sender=sender, + ) + ) return future # Returns the response of all handling agents - def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: + def broadcast_message( + self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None + ) -> Future[List[T]]: if cancellation_token is None: cancellation_token = CancellationToken() future: Future[List[T]] = asyncio.get_event_loop().create_future() - self._message_queue.append(BroadcastMessageEnvelope(message, future, cancellation_token)) + self._message_queue.append( + BroadcastMessageEnvelope( + message=message, future=future, cancellation_token=cancellation_token, sender=sender + ) + ) return future async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None: - recipient = message_envelope.destination + recipient = message_envelope.recipient if recipient not in self._agents: message_envelope.future.set_exception(Exception("Recipient not found")) return @@ -103,7 +130,14 @@ async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None: message_envelope.future.set_exception(e) return - self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future)) + self._message_queue.append( + ResponseMessageEnvelope( + message=response, + future=message_envelope.future, + sender=message_envelope.recipient, + recipient=message_envelope.sender, + ) + ) async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None: responses: List[Awaitable[T]] = [] @@ -117,7 +151,11 @@ async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T] message_envelope.future.set_exception(e) return - self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future)) + self._message_queue.append( + BroadcastResponseMessageEnvelope( + message=all_responses, future=message_envelope.future, recipient=message_envelope.sender + ) + ) async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None: message_envelope.future.set_result(message_envelope.message) @@ -134,18 +172,51 @@ async def process_next(self) -> None: message_envelope = self._message_queue.pop(0) match message_envelope: - case SendMessageEnvelope(message, destination, future, cancellation_token): - asyncio.create_task( - self._process_send(SendMessageEnvelope(message, destination, future, cancellation_token)) - ) - case BroadcastMessageEnvelope(message, future, cancellation_token): - asyncio.create_task( - self._process_broadcast(BroadcastMessageEnvelope(message, future, cancellation_token)) - ) - case ResponseMessageEnvelope(message, future): - asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future))) - case BroadcastResponseMessageEnvelope(message, future): - asyncio.create_task(self._process_broadcast_response(BroadcastResponseMessageEnvelope(message, future))) + case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): + if self._before_send is not None: + temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient) + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + future.set_exception(MessageDroppedException()) + return + + message_envelope.message = cast(T, temp_message) + + asyncio.create_task(self._process_send(message_envelope)) + case BroadcastMessageEnvelope( + message=message, + sender=sender, + future=future, + ): + if self._before_send is not None: + temp_message = await self._before_send.on_broadcast(message, sender=sender) + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + future.set_exception(MessageDroppedException()) + return + + message_envelope.message = cast(T, temp_message) + + asyncio.create_task(self._process_broadcast(message_envelope)) + case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): + if self._before_send is not None: + temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient) + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + future.set_exception(MessageDroppedException()) + return + + message_envelope.message = cast(T, temp_message) + + asyncio.create_task(self._process_response(message_envelope)) + + case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future): + if self._before_send is not None: + temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient) + if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage): + future.set_exception(MessageDroppedException()) + return + + message_envelope.message = list(temp_message_list) # type: ignore + + asyncio.create_task(self._process_broadcast_response(message_envelope)) # Yield control to the message loop to allow other tasks to run await asyncio.sleep(0) diff --git a/src/agnext/core/agent_runtime.py b/src/agnext/core/agent_runtime.py index 6c936cf7084b..aae215246293 100644 --- a/src/agnext/core/agent_runtime.py +++ b/src/agnext/core/agent_runtime.py @@ -14,8 +14,15 @@ def add_agent(self, agent: Agent[T]) -> None: ... # Returns the response of the message def send_message( - self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None + self, + message: T, + recipient: Agent[T], + *, + sender: Agent[T] | None = None, + cancellation_token: CancellationToken | None, ) -> Future[T]: ... # Returns the response of all handling agents - def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: ... + def broadcast_message( + self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None + ) -> Future[List[T]]: ... diff --git a/src/agnext/core/base_agent.py b/src/agnext/core/base_agent.py index 5282c59e2e48..ee90258e50bc 100644 --- a/src/agnext/core/base_agent.py +++ b/src/agnext/core/base_agent.py @@ -28,17 +28,19 @@ def subscriptions(self) -> Sequence[Type[T]]: async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ... def _send_message( - self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None + self, message: T, recipient: Agent[T], cancellation_token: CancellationToken | None = None ) -> Future[T]: if cancellation_token is None: cancellation_token = CancellationToken() - future = self._router.send_message(message, destination, cancellation_token) + future = self._router.send_message( + message, sender=self, recipient=recipient, cancellation_token=cancellation_token + ) cancellation_token.link_future(future) return future def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: if cancellation_token is None: cancellation_token = CancellationToken() - future = self._router.broadcast_message(message, cancellation_token) + future = self._router.broadcast_message(message, sender=self, cancellation_token=cancellation_token) cancellation_token.link_future(future) return future diff --git a/src/agnext/core/exceptions.py b/src/agnext/core/exceptions.py index 8d305d54e058..c3b21b60ec39 100644 --- a/src/agnext/core/exceptions.py +++ b/src/agnext/core/exceptions.py @@ -4,3 +4,7 @@ class CantHandleException(Exception): class UndeliverableException(Exception): """Raised when a message can't be delivered.""" + + +class MessageDroppedException(Exception): + """Raised when a message is dropped.""" diff --git a/src/agnext/core/intervention.py b/src/agnext/core/intervention.py new file mode 100644 index 000000000000..a8e5833eb226 --- /dev/null +++ b/src/agnext/core/intervention.py @@ -0,0 +1,39 @@ +from typing import Awaitable, Callable, Protocol, Sequence, TypeVar, final + +from agnext.core.agent import Agent + + +@final +class DropMessage: ... + + +T = TypeVar("T") + +InterventionFunction = Callable[[T], T | Awaitable[type[DropMessage]]] + + +class InterventionHandler(Protocol[T]): + async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]: ... + async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]: ... + async def on_response( + self, message: T, *, sender: Agent[T], recipient: Agent[T] | None + ) -> T | type[DropMessage]: ... + async def on_broadcast_response( + self, message: Sequence[T], *, recipient: Agent[T] | None + ) -> Sequence[T] | type[DropMessage]: ... + + +class DefaultInterventionHandler(InterventionHandler[T]): + async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]: + return message + + async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]: + return message + + async def on_response(self, message: T, *, sender: Agent[T], recipient: Agent[T] | None) -> T | type[DropMessage]: + return message + + async def on_broadcast_response( + self, message: Sequence[T], *, recipient: Agent[T] | None + ) -> Sequence[T] | type[DropMessage]: + return message diff --git a/test.sh b/test.sh index bc7fff177887..636d66382fa4 100755 --- a/test.sh +++ b/test.sh @@ -19,3 +19,5 @@ echo "--- Running pyright ---" pyright echo "--- Running mypy ---" mypy +echo "--- Running pytest ---" +pytest \ No newline at end of file diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index f5996bfcbf43..25380604335b 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -58,7 +58,7 @@ async def test_cancellation_with_token() -> None: long_running = LongRunningAgent("name", router) token = CancellationToken() - response = router.send_message(MessageType(), long_running, token) + response = router.send_message(MessageType(), recipient=long_running, cancellation_token=token) assert not response.done() await router.process_next() @@ -81,7 +81,7 @@ async def test_nested_cancellation_only_outer_called() -> None: nested = NestingLongRunningAgent("nested", router, long_running) token = CancellationToken() - response = router.send_message(MessageType(), nested, token) + response = router.send_message(MessageType(), nested, cancellation_token=token) assert not response.done() await router.process_next() @@ -104,7 +104,7 @@ async def test_nested_cancellation_inner_called() -> None: nested = NestingLongRunningAgent("nested", router, long_running) token = CancellationToken() - response = router.send_message(MessageType(), nested, token) + response = router.send_message(MessageType(), nested, cancellation_token=token) assert not response.done() await router.process_next() diff --git a/tests/test_intervention.py b/tests/test_intervention.py new file mode 100644 index 000000000000..79a5c5b03e9f --- /dev/null +++ b/tests/test_intervention.py @@ -0,0 +1,91 @@ +import pytest +from dataclasses import dataclass + +from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler +from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime +from agnext.core.agent import Agent +from agnext.core.agent_runtime import AgentRuntime +from agnext.core.cancellation_token import CancellationToken +from agnext.core.exceptions import MessageDroppedException +from agnext.core.intervention import DefaultInterventionHandler, DropMessage + +@dataclass +class MessageType: + ... + +class LoopbackAgent(TypeRoutedAgent[MessageType]): + def __init__(self, name: str, router: AgentRuntime[MessageType]) -> None: + super().__init__(name, router) + self.num_calls = 0 + + + @message_handler(MessageType) + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + self.num_calls += 1 + return message + +@pytest.mark.asyncio +async def test_intervention_count_messages() -> None: + + class DebugInterventionHandler(DefaultInterventionHandler[MessageType]): + def __init__(self): + self.num_messages = 0 + + async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType: + self.num_messages += 1 + return message + + handler = DebugInterventionHandler() + router = SingleThreadedAgentRuntime[MessageType](before_send=handler) + + long_running = LoopbackAgent("name", router) + response = router.send_message(MessageType(), recipient=long_running) + + while not response.done(): + await router.process_next() + + assert handler.num_messages == 1 + assert long_running.num_calls == 1 + +@pytest.mark.asyncio +async def test_intervention_drop_send() -> None: + + class DropSendInterventionHandler(DefaultInterventionHandler[MessageType]): + async def on_send(self, message: MessageType, *, sender: Agent[MessageType] | None, recipient: Agent[MessageType]) -> MessageType | type[DropMessage]: + return DropMessage + + handler = DropSendInterventionHandler() + router = SingleThreadedAgentRuntime[MessageType](before_send=handler) + + long_running = LoopbackAgent("name", router) + response = router.send_message(MessageType(), recipient=long_running) + + while not response.done(): + await router.process_next() + + with pytest.raises(MessageDroppedException): + await response + + assert long_running.num_calls == 0 + + +@pytest.mark.asyncio +async def test_intervention_drop_response() -> None: + + class DropResponseInterventionHandler(DefaultInterventionHandler[MessageType]): + async def on_response(self, message: MessageType, *, sender: Agent[MessageType], recipient: Agent[MessageType] | None) -> MessageType | type[DropMessage]: + return DropMessage + + handler = DropResponseInterventionHandler() + router = SingleThreadedAgentRuntime[MessageType](before_send=handler) + + long_running = LoopbackAgent("name", router) + response = router.send_message(MessageType(), recipient=long_running) + + while not response.done(): + await router.process_next() + + with pytest.raises(MessageDroppedException): + await response + + assert long_running.num_calls == 1