Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions src/a2a/server/events/event_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
"""Type alias for events that can be enqueued."""

DEFAULT_MAX_QUEUE_SIZE = 1024


@trace_class(kind=SpanKind.SERVER)
class EventQueue:
Expand All @@ -27,27 +29,38 @@ class EventQueue:
to create child queues that receive the same events.
"""

def __init__(self) -> None:
def __init__(self, max_queue_size=DEFAULT_MAX_QUEUE_SIZE) -> None:
"""Initializes the EventQueue."""
self.queue: asyncio.Queue[Event] = asyncio.Queue()

# Make sure the `asyncio.Queue` is bounded.
# If it's unbounded (maxsize=0), then `queue.put()` never needs to wait,
# and so the streaming won't work correctly.
if max_queue_size <= 0:
raise ValueError('max_queue_size must be greater than 0')

self.queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=max_queue_size)
self._children: list[EventQueue] = []
self._is_closed = False
self._lock = asyncio.Lock()
logger.debug('EventQueue initialized.')

def enqueue_event(self, event: Event):
async def enqueue_event(self, event: Event):
"""Enqueues an event to this queue and all its children.

Args:
event: The event object to enqueue.
"""
if self._is_closed:
logger.warning('Queue is closed. Event will not be enqueued.')
return
async with self._lock:
if self._is_closed:
logger.warning('Queue is closed. Event will not be enqueued.')
return

logger.debug(f'Enqueuing event of type: {type(event)}')
self.queue.put_nowait(event)

# Make sure to use put instead of put_nowait to avoid blocking the event loop.
await self.queue.put(event)
for child in self._children:
child.enqueue_event(event)
await child.enqueue_event(event)

async def dequeue_event(self, no_wait: bool = False) -> Event:
"""Dequeues an event from the queue.
Expand Down
28 changes: 14 additions & 14 deletions src/a2a/server/tasks/task_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
self.task_id = task_id
self.context_id = context_id

def update_status(
async def update_status(
self,
state: TaskState,
message: Message | None = None,
Expand All @@ -52,7 +52,7 @@ def update_status(
current_timestamp = (
timestamp if timestamp else datetime.now(timezone.utc).isoformat()
)
self.event_queue.enqueue_event(
await self.event_queue.enqueue_event(
TaskStatusUpdateEvent(
taskId=self.task_id,
contextId=self.context_id,
Expand All @@ -65,7 +65,7 @@ def update_status(
)
)

def add_artifact(
async def add_artifact(
self,
parts: list[Part],
artifact_id: str = str(uuid.uuid4()),
Expand All @@ -82,7 +82,7 @@ def add_artifact(
append: Optional boolean indicating if this chunk appends to a previous one.
last_chunk: Optional boolean indicating if this is the last chunk.
"""
self.event_queue.enqueue_event(
await self.event_queue.enqueue_event(
TaskArtifactUpdateEvent(
taskId=self.task_id,
contextId=self.context_id,
Expand All @@ -95,32 +95,32 @@ def add_artifact(
)
)

def complete(self, message: Message | None = None):
async def complete(self, message: Message | None = None):
"""Marks the task as completed and publishes a final status update."""
self.update_status(
await self.update_status(
TaskState.completed,
message=message,
final=True,
)

def failed(self, message: Message | None = None):
async def failed(self, message: Message | None = None):
"""Marks the task as failed and publishes a final status update."""
self.update_status(TaskState.failed, message=message, final=True)
await self.update_status(TaskState.failed, message=message, final=True)

def reject(self, message: Message | None = None):
async def reject(self, message: Message | None = None):
"""Marks the task as rejected and publishes a final status update."""
self.update_status(TaskState.rejected, message=message, final=True)
await self.update_status(TaskState.rejected, message=message, final=True)

def submit(self, message: Message | None = None):
async def submit(self, message: Message | None = None):
"""Marks the task as submitted and publishes a status update."""
self.update_status(
await self.update_status(
TaskState.submitted,
message=message,
)

def start_work(self, message: Message | None = None):
async def start_work(self, message: Message | None = None):
"""Marks the task as working and publishes a status update."""
self.update_status(
await self.update_status(
TaskState.working,
message=message,
)
Expand Down
10 changes: 5 additions & 5 deletions tests/server/events/test_event_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def event_queue() -> EventQueue:
async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
"""Test that an event can be enqueued and dequeued."""
event = Message(**MESSAGE_PAYLOAD)
event_queue.enqueue_event(event)
await event_queue.enqueue_event(event)
dequeued_event = await event_queue.dequeue_event()
assert dequeued_event == event

Expand All @@ -48,7 +48,7 @@ async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:
async def test_dequeue_event_no_wait(event_queue: EventQueue) -> None:
"""Test dequeue_event with no_wait=True."""
event = Task(**MINIMAL_TASK)
event_queue.enqueue_event(event)
await event_queue.enqueue_event(event)
dequeued_event = await event_queue.dequeue_event(no_wait=True)
assert dequeued_event == event

Expand All @@ -71,7 +71,7 @@ async def test_dequeue_event_wait(event_queue: EventQueue) -> None:
status=TaskStatus(state=TaskState.working),
final=True,
)
event_queue.enqueue_event(event)
await event_queue.enqueue_event(event)
dequeued_event = await event_queue.dequeue_event()
assert dequeued_event == event

Expand All @@ -84,7 +84,7 @@ async def test_task_done(event_queue: EventQueue) -> None:
contextId='session-xyz',
artifact=Artifact(artifactId='11', parts=[Part(TextPart(text='text'))]),
)
event_queue.enqueue_event(event)
await event_queue.enqueue_event(event)
_ = await event_queue.dequeue_event()
event_queue.task_done()

Expand All @@ -99,6 +99,6 @@ async def test_enqueue_different_event_types(
JSONRPCError(code=111, message='rpc error'),
]
for event in events:
event_queue.enqueue_event(event)
await event_queue.enqueue_event(event)
dequeued_event = await event_queue.dequeue_event()
assert dequeued_event == event
77 changes: 77 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import time

import pytest

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore, TaskUpdater
from a2a.types import (
Message,
MessageSendParams,
Part,
Role,
TaskState,
TextPart,
)


class DummyAgentExecutor(AgentExecutor):
async def execute(self, context: RequestContext, event_queue: EventQueue):
task_updater = TaskUpdater(
event_queue, context.task_id, context.context_id
)
async for i in self._run():
parts = [Part(root=TextPart(text=f'Event {i}'))]
try:
await task_updater.update_status(
TaskState.working,
message=task_updater.new_agent_message(parts),
)
except RuntimeError:
# Stop processing when the event loop is closed
break

async def _run(self):
for i in range(1_000_000): # Simulate a long-running stream
yield i

async def cancel(self, context: RequestContext, event_queue: EventQueue):
pass


@pytest.mark.asyncio
async def test_on_message_send_stream():
request_handler = DefaultRequestHandler(
DummyAgentExecutor(), InMemoryTaskStore()
)
message_params = MessageSendParams(
message=Message(
role=Role.user,
messageId='msg-123',
parts=[Part(root=TextPart(text='How are you?'))],
),
)

async def consume_stream():
events = []
async for event in request_handler.on_message_send_stream(
message_params
):
events.append(event)
if len(events) >= 3:
break # Stop after a few events

return events

# Consume first 3 events from the stream and measure time
start = time.perf_counter()
events = await consume_stream()
elapsed = time.perf_counter() - start

# Assert we received events quickly
assert len(events) == 3
assert elapsed < 0.5

texts = [p.root.text for e in events for p in e.status.message.parts]
assert texts == ['Event 0', 'Event 1', 'Event 2']
Loading