Skip to content

Commit

Permalink
Add support for task cancellation (autogenhub#7)
Browse files Browse the repository at this point in the history
* Add support for task cancellation

* add tests to CI

* matrix for python testing
  • Loading branch information
jackgerrits authored May 20, 2024
1 parent f80c42e commit 5afbadb
Show file tree
Hide file tree
Showing 13 changed files with 267 additions and 66 deletions.
21 changes: 12 additions & 9 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,15 @@ jobs:
- run: pip install ".[dev]"
- uses: jakebailey/pyright-action@v2

# test:
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v5
# with:
# python-version: '3.10'
# - run: pip install ".[dev]"
# - run: pytest
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["pypy3.10", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- run: pip install ".[dev]"
- run: pytest
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
3 changes: 1 addition & 2 deletions examples/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent import Agent
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.message import Message


@dataclass
class MessageType(Message):
class MessageType:
body: str
sender: str

Expand Down
8 changes: 4 additions & 4 deletions examples/round_robin_chat.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from dataclasses import dataclass
import random
import asyncio
import random
from dataclasses import dataclass
from typing import List

from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.application_components.single_threaded_agent_runtime import SingleThreadedAgentRuntime
from agnext.core.agent_runtime import AgentRuntime
from agnext.core.message import Message


# TODO: a runtime should be able to handle multiple types of messages
# TODO: allow request and response to be different message types
# should support this in handlers.
@dataclass
class GroupChatMessage(Message):
class GroupChatMessage:
body: str
sender: str
require_response: bool
Expand Down
13 changes: 6 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,10 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"openai>=1.3",
"pillow",
"aiohttp",
"typing-extensions"
]
dependencies = ["openai>=1.3", "pillow", "aiohttp", "typing-extensions"]

[project.optional-dependencies]
dev = ["ruff", "pyright", "mypy", "pytest", "types-Pillow"]
dev = ["ruff", "pyright", "mypy", "pytest", "pytest-asyncio", "types-Pillow"]

[tool.setuptools.package-data]
agnext = ["py.typed"]
Expand Down Expand Up @@ -61,3 +56,7 @@ include = ["src", "examples"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false

[tool.pytest.ini_options]
minversion = "6.0"
testpaths = ["tests"]
15 changes: 7 additions & 8 deletions src/agnext/agent_components/type_routed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from agnext.core.agent_runtime import AgentRuntime
from agnext.core.base_agent import BaseAgent
from agnext.core.cancellation_token import CancellationToken
from agnext.core.exceptions import CantHandleException

from ..core.message import Message

T = TypeVar("T", bound=Message)
T = TypeVar("T")


# NOTE: this works on concrete types and not inheritance
Expand All @@ -22,7 +21,7 @@ class TypeRoutedAgent(BaseAgent[T]):
def __init__(self, name: str, router: AgentRuntime[T]) -> None:
super().__init__(name, router)

self._handlers: Dict[Type[Any], Callable[[T], Awaitable[T]]] = {}
self._handlers: Dict[Type[Any], Callable[[T, CancellationToken], Awaitable[T]]] = {}

router.add_agent(self)

Expand All @@ -37,12 +36,12 @@ def __init__(self, name: str, router: AgentRuntime[T]) -> None:
def subscriptions(self) -> Sequence[Type[T]]:
return list(self._handlers.keys())

async def on_message(self, message: T) -> T:
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T:
handler = self._handlers.get(type(message))
if handler is not None:
return await handler(message)
return await handler(message, cancellation_token)
else:
return await self.on_unhandled_message(message)
return await self.on_unhandled_message(message, cancellation_token)

async def on_unhandled_message(self, message: T) -> T:
async def on_unhandled_message(self, message: T, cancellation_token: CancellationToken) -> T:
raise CantHandleException()
54 changes: 41 additions & 13 deletions src/agnext/application_components/single_threaded_agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from dataclasses import dataclass
from typing import Awaitable, Dict, Generic, List, Set, Type, TypeVar

from agnext.core.cancellation_token import CancellationToken

from ..core.agent import Agent
from ..core.agent_runtime import AgentRuntime
from ..core.message import Message

T = TypeVar("T", bound=Message)
T = TypeVar("T")


@dataclass
Expand All @@ -17,6 +18,7 @@ class BroadcastMessageEnvelope(Generic[T]):

message: T
future: Future[List[T]]
cancellation_token: CancellationToken


@dataclass
Expand All @@ -27,6 +29,7 @@ class SendMessageEnvelope(Generic[T]):
message: T
destination: Agent[T]
future: Future[T]
cancellation_token: CancellationToken


@dataclass
Expand Down Expand Up @@ -64,17 +67,26 @@ def add_agent(self, agent: Agent[T]) -> None:
self._agents.add(agent)

# Returns the response of the message
def send_message(self, message: T, destination: Agent[T]) -> Future[T]:
def send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
) -> Future[T]:
if cancellation_token is None:
cancellation_token = CancellationToken()

loop = asyncio.get_event_loop()
future: Future[T] = loop.create_future()

self._message_queue.append(SendMessageEnvelope(message, destination, future))
self._message_queue.append(SendMessageEnvelope(message, destination, future, cancellation_token))

return future

# Returns the response of all handling agents
def broadcast_message(self, message: T) -> Future[List[T]]:
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
if cancellation_token is None:
cancellation_token = CancellationToken()

future: Future[List[T]] = asyncio.get_event_loop().create_future()
self._message_queue.append(BroadcastMessageEnvelope(message, future))
self._message_queue.append(BroadcastMessageEnvelope(message, future, cancellation_token))
return future

async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None:
Expand All @@ -83,16 +95,28 @@ async def _process_send(self, message_envelope: SendMessageEnvelope[T]) -> None:
message_envelope.future.set_exception(Exception("Recipient not found"))
return

response = await recipient.on_message(message_envelope.message)
try:
response = await recipient.on_message(
message_envelope.message, cancellation_token=message_envelope.cancellation_token
)
except BaseException as e:
message_envelope.future.set_exception(e)
return

self._message_queue.append(ResponseMessageEnvelope(response, message_envelope.future))

async def _process_broadcast(self, message_envelope: BroadcastMessageEnvelope[T]) -> None:
responses: List[Awaitable[T]] = []
for agent in self._per_type_subscribers.get(type(message_envelope.message), []):
future = agent.on_message(message_envelope.message)
future = agent.on_message(message_envelope.message, cancellation_token=message_envelope.cancellation_token)
responses.append(future)

all_responses = await asyncio.gather(*responses)
try:
all_responses = await asyncio.gather(*responses)
except BaseException as e:
message_envelope.future.set_exception(e)
return

self._message_queue.append(BroadcastResponseMessageEnvelope(all_responses, message_envelope.future))

async def _process_response(self, message_envelope: ResponseMessageEnvelope[T]) -> None:
Expand All @@ -110,10 +134,14 @@ async def process_next(self) -> None:
message_envelope = self._message_queue.pop(0)

match message_envelope:
case SendMessageEnvelope(message, destination, future):
asyncio.create_task(self._process_send(SendMessageEnvelope(message, destination, future)))
case BroadcastMessageEnvelope(message, future):
asyncio.create_task(self._process_broadcast(BroadcastMessageEnvelope(message, future)))
case SendMessageEnvelope(message, destination, future, cancellation_token):
asyncio.create_task(
self._process_send(SendMessageEnvelope(message, destination, future, cancellation_token))
)
case BroadcastMessageEnvelope(message, future, cancellation_token):
asyncio.create_task(
self._process_broadcast(BroadcastMessageEnvelope(message, future, cancellation_token))
)
case ResponseMessageEnvelope(message, future):
asyncio.create_task(self._process_response(ResponseMessageEnvelope(message, future)))
case BroadcastResponseMessageEnvelope(message, future):
Expand Down
6 changes: 3 additions & 3 deletions src/agnext/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Protocol, Sequence, Type, TypeVar

from .message import Message
from agnext.core.cancellation_token import CancellationToken

T = TypeVar("T", bound=Message)
T = TypeVar("T")


class Agent(Protocol[T]):
Expand All @@ -12,4 +12,4 @@ def name(self) -> str: ...
@property
def subscriptions(self) -> Sequence[Type[T]]: ...

async def on_message(self, message: T) -> T: ...
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...
11 changes: 6 additions & 5 deletions src/agnext/core/agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from typing import List, Protocol, TypeVar

from agnext.core.agent import Agent
from agnext.core.cancellation_token import CancellationToken

from .message import Message

T = TypeVar("T", bound=Message)
T = TypeVar("T")

# Undeliverable - error

Expand All @@ -14,7 +13,9 @@ class AgentRuntime(Protocol[T]):
def add_agent(self, agent: Agent[T]) -> None: ...

# Returns the response of the message
def send_message(self, message: T, destination: Agent[T]) -> Future[T]: ...
def send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
) -> Future[T]: ...

# Returns the response of all handling agents
def broadcast_message(self, message: T) -> Future[List[T]]: ...
def broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]: ...
28 changes: 19 additions & 9 deletions src/agnext/core/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import List, Sequence, Type, TypeVar

from agnext.core.agent_runtime import AgentRuntime
from agnext.core.cancellation_token import CancellationToken

from .agent import Agent
from .message import Message

T = TypeVar("T", bound=Message)
T = TypeVar("T")


class BaseAgent(ABC, Agent[T]):
Expand All @@ -25,10 +25,20 @@ def subscriptions(self) -> Sequence[Type[T]]:
return []

@abstractmethod
async def on_message(self, message: T) -> T: ...

def _send_message(self, message: T, destination: Agent[T]) -> Future[T]:
return self._router.send_message(message, destination)

def _broadcast_message(self, message: T) -> Future[List[T]]:
return self._router.broadcast_message(message)
async def on_message(self, message: T, cancellation_token: CancellationToken) -> T: ...

def _send_message(
self, message: T, destination: Agent[T], cancellation_token: CancellationToken | None = None
) -> Future[T]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.send_message(message, destination, cancellation_token)
cancellation_token.link_future(future)
return future

def _broadcast_message(self, message: T, cancellation_token: CancellationToken | None = None) -> Future[List[T]]:
if cancellation_token is None:
cancellation_token = CancellationToken()
future = self._router.broadcast_message(message, cancellation_token)
cancellation_token.link_future(future)
return future
39 changes: 39 additions & 0 deletions src/agnext/core/cancellation_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import threading
from asyncio import Future
from typing import Any, Callable, List


class CancellationToken:
def __init__(self) -> None:
self._cancelled: bool = False
self._lock: threading.Lock = threading.Lock()
self._callbacks: List[Callable[[], None]] = []

def cancel(self) -> None:
with self._lock:
if not self._cancelled:
self._cancelled = True
for callback in self._callbacks:
callback()

def is_cancelled(self) -> bool:
with self._lock:
return self._cancelled

def add_callback(self, callback: Callable[[], None]) -> None:
with self._lock:
if self._cancelled:
callback()
else:
self._callbacks.append(callback)

def link_future(self, future: Future[Any]) -> None:
with self._lock:
if self._cancelled:
future.cancel()
else:

def _cancel() -> None:
future.cancel()

self._callbacks.append(_cancel)
6 changes: 0 additions & 6 deletions src/agnext/core/message.py

This file was deleted.

Loading

0 comments on commit 5afbadb

Please sign in to comment.