Skip to content

Commit

Permalink
Implement intervention (autogenhub#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits authored May 20, 2024
1 parent 5afbadb commit 77c8cca
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 34 deletions.
123 changes: 97 additions & 26 deletions src/agnext/application_components/single_threaded_agent_runtime.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,62 @@
import asyncio
from asyncio import Future
from dataclasses import dataclass
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar, cast

from agnext.core.cancellation_token import CancellationToken
from agnext.core.exceptions import MessageDroppedException
from agnext.core.intervention import DropMessage, InterventionHandler

from ..core.agent import Agent
from ..core.agent_runtime import AgentRuntime

T = TypeVar("T")


@dataclass
@dataclass(kw_only=True)
class BroadcastMessageEnvelope(Generic[T]):
"""A message envelope for broadcasting messages to all agents that can handle
the message of the type T."""

message: T
future: Future[List[T]]
cancellation_token: CancellationToken
sender: Agent[T] | None


@dataclass
@dataclass(kw_only=True)
class SendMessageEnvelope(Generic[T]):
"""A message envelope for sending a message to a specific agent that can handle
the message of the type T."""

message: T
destination: Agent[T]
sender: Agent[T] | None
recipient: Agent[T]
future: Future[T]
cancellation_token: CancellationToken


@dataclass
@dataclass(kw_only=True)
class ResponseMessageEnvelope(Generic[T]):
"""A message envelope for sending a response to a message."""

message: T
future: Future[T]
sender: Agent[T]
recipient: Agent[T] | None


@dataclass
@dataclass(kw_only=True)
class BroadcastResponseMessageEnvelope(Generic[T]):
"""A message envelope for sending a response to a message."""

message: List[T]
future: Future[List[T]]
recipient: Agent[T] | None


class SingleThreadedAgentRuntime(AgentRuntime[T]):
def __init__(self) -> None:
def __init__(self, *, before_send: InterventionHandler[T] | None = None) -> None:
self._message_queue: List[
BroadcastMessageEnvelope[T]
| SendMessageEnvelope[T]
Expand All @@ -58,6 +65,7 @@ def __init__(self) -> None:
] = []
self._per_type_subscribers: Dict[Type[T], List[Agent[T]]] = {}
self._agents: Set[Agent[T]] = set()
self._before_send = before_send

def add_agent(self, agent: Agent[T]) -> None:
for message_type in agent.subscriptions:
Expand All @@ -68,29 +76,48 @@ def add_agent(self, agent: Agent[T]) -> None:

# Returns the response of the message
def send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
self,
message: T,
recipient: Agent[T],
*,
sender: Agent[T] | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[T]:
if cancellation_token is None:
cancellation_token = CancellationToken()

loop = asyncio.get_event_loop()
future: Future[T] = loop.create_future()

self._message_queue.append(SendMessageEnvelope(message, destination, future, cancellation_token))
self._message_queue.append(
SendMessageEnvelope(
message=message,
recipient=recipient,
future=future,
cancellation_token=cancellation_token,
sender=sender,
)
)

return future

# Returns the response of all handling agents
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
def broadcast_message(
self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None
) -> Future[List[T]]:
if cancellation_token is None:
cancellation_token = CancellationToken()

future: Future[List[T]] = asyncio.get_event_loop().create_future()
self._message_queue.append(BroadcastMessageEnvelope(message, future, cancellation_token))
self._message_queue.append(
BroadcastMessageEnvelope(
message=message, future=future, cancellation_token=cancellation_token, sender=sender
)
)
return future

async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None:
recipient = message_envelope.destination
recipient = message_envelope.recipient
if recipient not in self._agents:
message_envelope.future.set_exception(Exception("Recipient not found"))
return
Expand All @@ -103,7 +130,14 @@ async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None:
message_envelope.future.set_exception(e)
return

self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future))
self._message_queue.append(
ResponseMessageEnvelope(
message=response,
future=message_envelope.future,
sender=message_envelope.recipient,
recipient=message_envelope.sender,
)
)

async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None:
responses: List[Awaitable[T]] = []
Expand All @@ -117,7 +151,11 @@ async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]
message_envelope.future.set_exception(e)
return

self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future))
self._message_queue.append(
BroadcastResponseMessageEnvelope(
message=all_responses, future=message_envelope.future, recipient=message_envelope.sender
)
)

async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None:
message_envelope.future.set_result(message_envelope.message)
Expand All @@ -134,18 +172,51 @@ async def process_next(self) -> None:
message_envelope = self._message_queue.pop(0)

match message_envelope:
case SendMessageEnvelope(message, destination, future, cancellation_token):
asyncio.create_task(
self._process_send(SendMessageEnvelope(message, destination, future, cancellation_token))
)
case BroadcastMessageEnvelope(message, future, cancellation_token):
asyncio.create_task(
self._process_broadcast(BroadcastMessageEnvelope(message, future, cancellation_token))
)
case ResponseMessageEnvelope(message, future):
asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future)))
case BroadcastResponseMessageEnvelope(message, future):
asyncio.create_task(self._process_broadcast_response(BroadcastResponseMessageEnvelope(message, future)))
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return

message_envelope.message = cast(T, temp_message)

asyncio.create_task(self._process_send(message_envelope))
case BroadcastMessageEnvelope(
message=message,
sender=sender,
future=future,
):
if self._before_send is not None:
temp_message = await self._before_send.on_broadcast(message, sender=sender)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return

message_envelope.message = cast(T, temp_message)

asyncio.create_task(self._process_broadcast(message_envelope))
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._before_send is not None:
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
future.set_exception(MessageDroppedException())
return

message_envelope.message = cast(T, temp_message)

asyncio.create_task(self._process_response(message_envelope))

case BroadcastResponseMessageEnvelope(message=message, recipient=recipient, future=future):
if self._before_send is not None:
temp_message_list = await self._before_send.on_broadcast_response(message, recipient=recipient)
if temp_message_list is DropMessage or isinstance(temp_message_list, DropMessage):
future.set_exception(MessageDroppedException())
return

message_envelope.message = list(temp_message_list) # type: ignore

asyncio.create_task(self._process_broadcast_response(message_envelope))

# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)
11 changes: 9 additions & 2 deletions src/agnext/core/agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,15 @@ def add_agent(self, agent: Agent[T]) -> None: ...

# Returns the response of the message
def send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
self,
message: T,
recipient: Agent[T],
*,
sender: Agent[T] | None = None,
cancellation_token: CancellationToken | None,
) -> Future[T]: ...

# Returns the response of all handling agents
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: ...
def broadcast_message(
self, message: T, *, sender: Agent[T] | None = None, cancellation_token: CancellationToken | None = None
) -> Future[List[T]]: ...
8 changes: 5 additions & 3 deletions src/agnext/core/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,19 @@ def subscriptions(self) -> Sequence[Type[T]]:
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...

def _send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
self, message: T, recipient: Agent[T], cancellation_token: CancellationToken | None = None
) -> Future[T]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.send_message(message, destination, cancellation_token)
future = self._router.send_message(
message, sender=self, recipient=recipient, cancellation_token=cancellation_token
)
cancellation_token.link_future(future)
return future

def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.broadcast_message(message, cancellation_token)
future = self._router.broadcast_message(message, sender=self, cancellation_token=cancellation_token)
cancellation_token.link_future(future)
return future
4 changes: 4 additions & 0 deletions src/agnext/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ class CantHandleException(Exception):

class UndeliverableException(Exception):
"""Raised when a message can't be delivered."""


class MessageDroppedException(Exception):
"""Raised when a message is dropped."""
39 changes: 39 additions & 0 deletions src/agnext/core/intervention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Awaitable, Callable, Protocol, Sequence, TypeVar, final

from agnext.core.agent import Agent


@final
class DropMessage: ...


T = TypeVar("T")

InterventionFunction = Callable[[T], T | Awaitable[type[DropMessage]]]


class InterventionHandler(Protocol[T]):
async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]: ...
async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]: ...
async def on_response(
self, message: T, *, sender: Agent[T], recipient: Agent[T] | None
) -> T | type[DropMessage]: ...
async def on_broadcast_response(
self, message: Sequence[T], *, recipient: Agent[T] | None
) -> Sequence[T] | type[DropMessage]: ...


class DefaultInterventionHandler(InterventionHandler[T]):
async def on_send(self, message: T, *, sender: Agent[T] | None, recipient: Agent[T]) -> T | type[DropMessage]:
return message

async def on_broadcast(self, message: T, *, sender: Agent[T] | None) -> T | type[DropMessage]:
return message

async def on_response(self, message: T, *, sender: Agent[T], recipient: Agent[T] | None) -> T | type[DropMessage]:
return message

async def on_broadcast_response(
self, message: Sequence[T], *, recipient: Agent[T] | None
) -> Sequence[T] | type[DropMessage]:
return message
2 changes: 2 additions & 0 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,5 @@ echo "--- Running pyright ---"
pyright
echo "--- Running mypy ---"
mypy
echo "--- Running pytest ---"
pytest
6 changes: 3 additions & 3 deletions tests/test_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def test_cancellation_with_token() -> None:

long_running = LongRunningAgent("name", router)
token = CancellationToken()
response = router.send_message(MessageType(), long_running, token)
response = router.send_message(MessageType(), recipient=long_running, cancellation_token=token)
assert not response.done()

await router.process_next()
Expand All @@ -81,7 +81,7 @@ async def test_nested_cancellation_only_outer_called() -> None:
nested = NestingLongRunningAgent("nested", router, long_running)

token = CancellationToken()
response = router.send_message(MessageType(), nested, token)
response = router.send_message(MessageType(), nested, cancellation_token=token)
assert not response.done()

await router.process_next()
Expand All @@ -104,7 +104,7 @@ async def test_nested_cancellation_inner_called() -> None:
nested = NestingLongRunningAgent("nested", router, long_running)

token = CancellationToken()
response = router.send_message(MessageType(), nested, token)
response = router.send_message(MessageType(), nested, cancellation_token=token)
assert not response.done()

await router.process_next()
Expand Down
Loading

0 comments on commit 77c8cca

Please sign in to comment.