From 2d7bcba1c0c5537eedd32340c2ec738bb8dae22d Mon Sep 17 00:00:00 2001 From: Mateusz Czubak Date: Tue, 3 Jun 2025 12:54:58 +0200 Subject: [PATCH] fix(stream): don't block event loop in EventQueue --- src/a2a/server/events/event_queue.py | 29 +- src/a2a/server/tasks/task_updater.py | 28 +- tests/server/events/test_event_queue.py | 10 +- .../test_default_request_handler.py | 77 ++++ tests/server/tasks/test_task_updater.py | 387 ++++++++++-------- 5 files changed, 323 insertions(+), 208 deletions(-) create mode 100644 tests/server/request_handlers/test_default_request_handler.py diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 73972f37..92fc7597 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -17,6 +17,8 @@ Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent """Type alias for events that can be enqueued.""" +DEFAULT_MAX_QUEUE_SIZE = 1024 + @trace_class(kind=SpanKind.SERVER) class EventQueue: @@ -27,27 +29,38 @@ class EventQueue: to create child queues that receive the same events. """ - def __init__(self) -> None: + def __init__(self, max_queue_size=DEFAULT_MAX_QUEUE_SIZE) -> None: """Initializes the EventQueue.""" - self.queue: asyncio.Queue[Event] = asyncio.Queue() + + # Make sure the `asyncio.Queue` is bounded. + # If it's unbounded (maxsize=0), then `queue.put()` never needs to wait, + # and so the streaming won't work correctly. + if max_queue_size <= 0: + raise ValueError('max_queue_size must be greater than 0') + + self.queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=max_queue_size) self._children: list[EventQueue] = [] self._is_closed = False self._lock = asyncio.Lock() logger.debug('EventQueue initialized.') - def enqueue_event(self, event: Event): + async def enqueue_event(self, event: Event): """Enqueues an event to this queue and all its children. Args: event: The event object to enqueue. """ - if self._is_closed: - logger.warning('Queue is closed. Event will not be enqueued.') - return + async with self._lock: + if self._is_closed: + logger.warning('Queue is closed. Event will not be enqueued.') + return + logger.debug(f'Enqueuing event of type: {type(event)}') - self.queue.put_nowait(event) + + # Make sure to use put instead of put_nowait to avoid blocking the event loop. + await self.queue.put(event) for child in self._children: - child.enqueue_event(event) + await child.enqueue_event(event) async def dequeue_event(self, no_wait: bool = False) -> Event: """Dequeues an event from the queue. diff --git a/src/a2a/server/tasks/task_updater.py b/src/a2a/server/tasks/task_updater.py index 1869eddf..69464fba 100644 --- a/src/a2a/server/tasks/task_updater.py +++ b/src/a2a/server/tasks/task_updater.py @@ -34,7 +34,7 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str): self.task_id = task_id self.context_id = context_id - def update_status( + async def update_status( self, state: TaskState, message: Message | None = None, @@ -52,7 +52,7 @@ def update_status( current_timestamp = ( timestamp if timestamp else datetime.now(timezone.utc).isoformat() ) - self.event_queue.enqueue_event( + await self.event_queue.enqueue_event( TaskStatusUpdateEvent( taskId=self.task_id, contextId=self.context_id, @@ -65,7 +65,7 @@ def update_status( ) ) - def add_artifact( + async def add_artifact( self, parts: list[Part], artifact_id: str = str(uuid.uuid4()), @@ -82,7 +82,7 @@ def add_artifact( append: Optional boolean indicating if this chunk appends to a previous one. last_chunk: Optional boolean indicating if this is the last chunk. """ - self.event_queue.enqueue_event( + await self.event_queue.enqueue_event( TaskArtifactUpdateEvent( taskId=self.task_id, contextId=self.context_id, @@ -95,32 +95,32 @@ def add_artifact( ) ) - def complete(self, message: Message | None = None): + async def complete(self, message: Message | None = None): """Marks the task as completed and publishes a final status update.""" - self.update_status( + await self.update_status( TaskState.completed, message=message, final=True, ) - def failed(self, message: Message | None = None): + async def failed(self, message: Message | None = None): """Marks the task as failed and publishes a final status update.""" - self.update_status(TaskState.failed, message=message, final=True) + await self.update_status(TaskState.failed, message=message, final=True) - def reject(self, message: Message | None = None): + async def reject(self, message: Message | None = None): """Marks the task as rejected and publishes a final status update.""" - self.update_status(TaskState.rejected, message=message, final=True) + await self.update_status(TaskState.rejected, message=message, final=True) - def submit(self, message: Message | None = None): + async def submit(self, message: Message | None = None): """Marks the task as submitted and publishes a status update.""" - self.update_status( + await self.update_status( TaskState.submitted, message=message, ) - def start_work(self, message: Message | None = None): + async def start_work(self, message: Message | None = None): """Marks the task as working and publishes a status update.""" - self.update_status( + await self.update_status( TaskState.working, message=message, ) diff --git a/tests/server/events/test_event_queue.py b/tests/server/events/test_event_queue.py index 8a9c163e..111989fb 100644 --- a/tests/server/events/test_event_queue.py +++ b/tests/server/events/test_event_queue.py @@ -39,7 +39,7 @@ def event_queue() -> EventQueue: async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None: """Test that an event can be enqueued and dequeued.""" event = Message(**MESSAGE_PAYLOAD) - event_queue.enqueue_event(event) + await event_queue.enqueue_event(event) dequeued_event = await event_queue.dequeue_event() assert dequeued_event == event @@ -48,7 +48,7 @@ async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None: async def test_dequeue_event_no_wait(event_queue: EventQueue) -> None: """Test dequeue_event with no_wait=True.""" event = Task(**MINIMAL_TASK) - event_queue.enqueue_event(event) + await event_queue.enqueue_event(event) dequeued_event = await event_queue.dequeue_event(no_wait=True) assert dequeued_event == event @@ -71,7 +71,7 @@ async def test_dequeue_event_wait(event_queue: EventQueue) -> None: status=TaskStatus(state=TaskState.working), final=True, ) - event_queue.enqueue_event(event) + await event_queue.enqueue_event(event) dequeued_event = await event_queue.dequeue_event() assert dequeued_event == event @@ -84,7 +84,7 @@ async def test_task_done(event_queue: EventQueue) -> None: contextId='session-xyz', artifact=Artifact(artifactId='11', parts=[Part(TextPart(text='text'))]), ) - event_queue.enqueue_event(event) + await event_queue.enqueue_event(event) _ = await event_queue.dequeue_event() event_queue.task_done() @@ -99,6 +99,6 @@ async def test_enqueue_different_event_types( JSONRPCError(code=111, message='rpc error'), ] for event in events: - event_queue.enqueue_event(event) + await event_queue.enqueue_event(event) dequeued_event = await event_queue.dequeue_event() assert dequeued_event == event diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py new file mode 100644 index 00000000..b5b31c81 --- /dev/null +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -0,0 +1,77 @@ +import time + +import pytest + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore, TaskUpdater +from a2a.types import ( + Message, + MessageSendParams, + Part, + Role, + TaskState, + TextPart, +) + + +class DummyAgentExecutor(AgentExecutor): + async def execute(self, context: RequestContext, event_queue: EventQueue): + task_updater = TaskUpdater( + event_queue, context.task_id, context.context_id + ) + async for i in self._run(): + parts = [Part(root=TextPart(text=f'Event {i}'))] + try: + await task_updater.update_status( + TaskState.working, + message=task_updater.new_agent_message(parts), + ) + except RuntimeError: + # Stop processing when the event loop is closed + break + + async def _run(self): + for i in range(1_000_000): # Simulate a long-running stream + yield i + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + pass + + +@pytest.mark.asyncio +async def test_on_message_send_stream(): + request_handler = DefaultRequestHandler( + DummyAgentExecutor(), InMemoryTaskStore() + ) + message_params = MessageSendParams( + message=Message( + role=Role.user, + messageId='msg-123', + parts=[Part(root=TextPart(text='How are you?'))], + ), + ) + + async def consume_stream(): + events = [] + async for event in request_handler.on_message_send_stream( + message_params + ): + events.append(event) + if len(events) >= 3: + break # Stop after a few events + + return events + + # Consume first 3 events from the stream and measure time + start = time.perf_counter() + events = await consume_stream() + elapsed = time.perf_counter() - start + + # Assert we received events quickly + assert len(events) == 3 + assert elapsed < 0.5 + + texts = [p.root.text for e in events for p in e.status.message.parts] + assert texts == ['Event 0', 'Event 1', 'Event 2'] diff --git a/tests/server/tasks/test_task_updater.py b/tests/server/tasks/test_task_updater.py index af057090..40a867f1 100644 --- a/tests/server/tasks/test_task_updater.py +++ b/tests/server/tasks/test_task_updater.py @@ -1,6 +1,6 @@ import uuid -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, patch import pytest @@ -17,216 +17,241 @@ ) -class TestTaskUpdater: - @pytest.fixture - def event_queue(self): - """Create a mock event queue for testing.""" - return Mock(spec=EventQueue) +@pytest.fixture +def event_queue(): + """Create a mock event queue for testing.""" + return AsyncMock(spec=EventQueue) - @pytest.fixture - def task_updater(self, event_queue): - """Create a TaskUpdater instance for testing.""" - return TaskUpdater( - event_queue=event_queue, - task_id='test-task-id', - context_id='test-context-id', - ) - @pytest.fixture - def sample_message(self): - """Create a sample message for testing.""" - return Message( - role=Role.agent, - taskId='test-task-id', - contextId='test-context-id', - messageId='test-message-id', - parts=[Part(root=TextPart(text='Test message'))], - ) +@pytest.fixture +def task_updater(event_queue): + """Create a TaskUpdater instance for testing.""" + return TaskUpdater( + event_queue=event_queue, + task_id='test-task-id', + context_id='test-context-id', + ) + - @pytest.fixture - def sample_parts(self): - """Create sample parts for testing.""" - return [Part(root=TextPart(text='Test part'))] - - def test_init(self, event_queue): - """Test that TaskUpdater initializes correctly.""" - task_updater = TaskUpdater( - event_queue=event_queue, - task_id='test-task-id', - context_id='test-context-id', - ) +@pytest.fixture +def sample_message(): + """Create a sample message for testing.""" + return Message( + role=Role.agent, + taskId='test-task-id', + contextId='test-context-id', + messageId='test-message-id', + parts=[Part(root=TextPart(text='Test message'))], + ) - assert task_updater.event_queue == event_queue - assert task_updater.task_id == 'test-task-id' - assert task_updater.context_id == 'test-context-id' - def test_update_status_without_message(self, task_updater, event_queue): - """Test updating status without a message.""" - task_updater.update_status(TaskState.working) +@pytest.fixture +def sample_parts(): + """Create sample parts for testing.""" + return [Part(root=TextPart(text='Test part'))] - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] - assert isinstance(event, TaskStatusUpdateEvent) - assert event.taskId == 'test-task-id' - assert event.contextId == 'test-context-id' - assert event.final is False - assert event.status.state == TaskState.working - assert event.status.message is None +def test_init(event_queue): + """Test that TaskUpdater initializes correctly.""" + task_updater = TaskUpdater( + event_queue=event_queue, + task_id='test-task-id', + context_id='test-context-id', + ) + + assert task_updater.event_queue == event_queue + assert task_updater.task_id == 'test-task-id' + assert task_updater.context_id == 'test-context-id' + - def test_update_status_with_message( - self, task_updater, event_queue, sample_message - ): - """Test updating status with a message.""" - task_updater.update_status(TaskState.working, message=sample_message) +@pytest.mark.asyncio +async def test_update_status_without_message(task_updater, event_queue): + """Test updating status without a message.""" + await task_updater.update_status(TaskState.working) - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] - assert isinstance(event, TaskStatusUpdateEvent) - assert event.taskId == 'test-task-id' - assert event.contextId == 'test-context-id' - assert event.final is False - assert event.status.state == TaskState.working - assert event.status.message == sample_message + assert isinstance(event, TaskStatusUpdateEvent) + assert event.taskId == 'test-task-id' + assert event.contextId == 'test-context-id' + assert event.final is False + assert event.status.state == TaskState.working + assert event.status.message is None - def test_update_status_final(self, task_updater, event_queue): - """Test updating status with final=True.""" - task_updater.update_status(TaskState.completed, final=True) - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] +@pytest.mark.asyncio +async def test_update_status_with_message( + task_updater, event_queue, sample_message +): + """Test updating status with a message.""" + await task_updater.update_status(TaskState.working, message=sample_message) - assert isinstance(event, TaskStatusUpdateEvent) - assert event.final is True - assert event.status.state == TaskState.completed + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] - def test_add_artifact_with_custom_id_and_name( - self, task_updater, event_queue, sample_parts - ): - """Test adding an artifact with a custom ID and name.""" - task_updater.add_artifact( - parts=sample_parts, - artifact_id='custom-artifact-id', - name='Custom Artifact', - ) + assert isinstance(event, TaskStatusUpdateEvent) + assert event.taskId == 'test-task-id' + assert event.contextId == 'test-context-id' + assert event.final is False + assert event.status.state == TaskState.working + assert event.status.message == sample_message - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] - assert isinstance(event, TaskArtifactUpdateEvent) - assert event.artifact.artifactId == 'custom-artifact-id' - assert event.artifact.name == 'Custom Artifact' - assert event.artifact.parts == sample_parts +@pytest.mark.asyncio +async def test_update_status_final(task_updater, event_queue): + """Test updating status with final=True.""" + await task_updater.update_status(TaskState.completed, final=True) - def test_complete_without_message(self, task_updater, event_queue): - """Test marking a task as completed without a message.""" - task_updater.complete() + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] + assert isinstance(event, TaskStatusUpdateEvent) + assert event.final is True + assert event.status.state == TaskState.completed - assert isinstance(event, TaskStatusUpdateEvent) - assert event.status.state == TaskState.completed - assert event.final is True - assert event.status.message is None - def test_complete_with_message( - self, task_updater, event_queue, sample_message - ): - """Test marking a task as completed with a message.""" - task_updater.complete(message=sample_message) +@pytest.mark.asyncio +async def test_add_artifact_with_custom_id_and_name( + task_updater, event_queue, sample_parts +): + """Test adding an artifact with a custom ID and name.""" + await task_updater.add_artifact( + parts=sample_parts, + artifact_id='custom-artifact-id', + name='Custom Artifact', + ) - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskArtifactUpdateEvent) + assert event.artifact.artifactId == 'custom-artifact-id' + assert event.artifact.name == 'Custom Artifact' + assert event.artifact.parts == sample_parts - assert isinstance(event, TaskStatusUpdateEvent) - assert event.status.state == TaskState.completed - assert event.final is True - assert event.status.message == sample_message - def test_submit_without_message(self, task_updater, event_queue): - """Test marking a task as submitted without a message.""" - task_updater.submit() +@pytest.mark.asyncio +async def test_complete_without_message(task_updater, event_queue): + """Test marking a task as completed without a message.""" + await task_updater.complete() - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] - assert isinstance(event, TaskStatusUpdateEvent) - assert event.status.state == TaskState.submitted - assert event.final is False - assert event.status.message is None + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.completed + assert event.final is True + assert event.status.message is None - def test_submit_with_message( - self, task_updater, event_queue, sample_message - ): - """Test marking a task as submitted with a message.""" - task_updater.submit(message=sample_message) - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] +@pytest.mark.asyncio +async def test_complete_with_message( + task_updater, event_queue, sample_message +): + """Test marking a task as completed with a message.""" + await task_updater.complete(message=sample_message) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.completed + assert event.final is True + assert event.status.message == sample_message + + +@pytest.mark.asyncio +async def test_submit_without_message(task_updater, event_queue): + """Test marking a task as submitted without a message.""" + await task_updater.submit() + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.submitted + assert event.final is False + assert event.status.message is None + - assert isinstance(event, TaskStatusUpdateEvent) - assert event.status.state == TaskState.submitted - assert event.final is False - assert event.status.message == sample_message +@pytest.mark.asyncio +async def test_submit_with_message( + task_updater, event_queue, sample_message +): + """Test marking a task as submitted with a message.""" + await task_updater.submit(message=sample_message) - def test_start_work_without_message(self, task_updater, event_queue): - """Test marking a task as working without a message.""" - task_updater.start_work() + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.submitted + assert event.final is False + assert event.status.message == sample_message - assert isinstance(event, TaskStatusUpdateEvent) - assert event.status.state == TaskState.working - assert event.final is False - assert event.status.message is None - def test_start_work_with_message( - self, task_updater, event_queue, sample_message +@pytest.mark.asyncio +async def test_start_work_without_message(task_updater, event_queue): + """Test marking a task as working without a message.""" + await task_updater.start_work() + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.working + assert event.final is False + assert event.status.message is None + + +@pytest.mark.asyncio +async def test_start_work_with_message( + task_updater, event_queue, sample_message +): + """Test marking a task as working with a message.""" + await task_updater.start_work(message=sample_message) + + event_queue.enqueue_event.assert_called_once() + event = event_queue.enqueue_event.call_args[0][0] + + assert isinstance(event, TaskStatusUpdateEvent) + assert event.status.state == TaskState.working + assert event.final is False + assert event.status.message == sample_message + + +def test_new_agent_message(task_updater, sample_parts): + """Test creating a new agent message.""" + with patch( + 'uuid.uuid4', + return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), ): - """Test marking a task as working with a message.""" - task_updater.start_work(message=sample_message) - - event_queue.enqueue_event.assert_called_once() - event = event_queue.enqueue_event.call_args[0][0] - - assert isinstance(event, TaskStatusUpdateEvent) - assert event.status.state == TaskState.working - assert event.final is False - assert event.status.message == sample_message - - def test_new_agent_message(self, task_updater, sample_parts): - """Test creating a new agent message.""" - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = task_updater.new_agent_message(parts=sample_parts) - - assert message.role == Role.agent - assert message.taskId == 'test-task-id' - assert message.contextId == 'test-context-id' - assert message.messageId == '12345678-1234-5678-1234-567812345678' - assert message.parts == sample_parts - assert message.metadata is None - - def test_new_agent_message_with_metadata(self, task_updater, sample_parts): - """Test creating a new agent message with metadata and final=True.""" - metadata = {'key': 'value'} - - with patch( - 'uuid.uuid4', - return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), - ): - message = task_updater.new_agent_message( - parts=sample_parts, metadata=metadata - ) - - assert message.role == Role.agent - assert message.taskId == 'test-task-id' - assert message.contextId == 'test-context-id' - assert message.messageId == '12345678-1234-5678-1234-567812345678' - assert message.parts == sample_parts - assert message.metadata == metadata + message = task_updater.new_agent_message(parts=sample_parts) + + assert message.role == Role.agent + assert message.taskId == 'test-task-id' + assert message.contextId == 'test-context-id' + assert message.messageId == '12345678-1234-5678-1234-567812345678' + assert message.parts == sample_parts + assert message.metadata is None + + +def test_new_agent_message_with_metadata(task_updater, sample_parts): + """Test creating a new agent message with metadata and final=True.""" + metadata = {'key': 'value'} + + with patch( + 'uuid.uuid4', + return_value=uuid.UUID('12345678-1234-5678-1234-567812345678'), + ): + message = task_updater.new_agent_message( + parts=sample_parts, metadata=metadata + ) + + assert message.role == Role.agent + assert message.taskId == 'test-task-id' + assert message.contextId == 'test-context-id' + assert message.messageId == '12345678-1234-5678-1234-567812345678' + assert message.parts == sample_parts + assert message.metadata == metadata