diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java index 33025ed5e..3a24f5145 100644 --- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java +++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java @@ -10,6 +10,8 @@ import io.a2a.jsonrpc.common.json.JsonProcessingException; import io.a2a.spec.A2AError; import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatusUpdateEvent; import org.jspecify.annotations.Nullable; @@ -64,11 +66,23 @@ private void handleMessage(String message, @Nullable Future future) { StreamingEventKind event = ProtoUtils.FromProto.streamingEventKind(response); eventHandler.accept(event); - if (event instanceof TaskStatusUpdateEvent && ((TaskStatusUpdateEvent) event).isFinal()) { - if (future != null) { - future.cancel(true); // close SSE channel + + // Client-side auto-close on final events to prevent connection leaks + // Handles both TaskStatusUpdateEvent and Task objects with final states + // This covers late subscriptions to completed tasks and ensures no connection leaks + boolean shouldClose = false; + if (event instanceof TaskStatusUpdateEvent tue && tue.isFinal()) { + shouldClose = true; + } else if (event instanceof Task task) { + TaskState state = task.status().state(); + if (state.isFinal()) { + shouldClose = true; } } + + if (shouldClose && future != null) { + future.cancel(true); // close SSE channel + } } catch (A2AError error) { if (errorHandler != null) { errorHandler.accept(error); diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java index ec74d2fbc..85e604da3 100644 --- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java +++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java @@ -10,6 +10,9 @@ import io.a2a.grpc.StreamResponse; import io.a2a.grpc.utils.ProtoUtils; import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; +import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatusUpdateEvent; import org.jspecify.annotations.Nullable; public class RestSSEEventListener { @@ -29,7 +32,7 @@ public void onMessage(String message, @Nullable Future completableFuture) log.fine("Streaming message received: " + message); io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder(); JsonFormat.parser().merge(message, builder); - handleMessage(builder.build()); + handleMessage(builder.build(), completableFuture); } catch (InvalidProtocolBufferException e) { errorHandler.accept(RestErrorMapper.mapRestError(message, 500)); } @@ -44,7 +47,7 @@ public void onError(Throwable throwable, @Nullable Future future) { } } - private void handleMessage(StreamResponse response) { + private void handleMessage(StreamResponse response, @Nullable Future future) { StreamingEventKind event; switch (response.getPayloadCase()) { case MESSAGE -> @@ -62,6 +65,23 @@ private void handleMessage(StreamResponse response) { } } eventHandler.accept(event); + + // Client-side auto-close on final events to prevent connection leaks + // Handles both TaskStatusUpdateEvent and Task objects with final states + // This 
covers late subscriptions to completed tasks and ensures no connection leaks + boolean shouldClose = false; + if (event instanceof TaskStatusUpdateEvent tue && tue.isFinal()) { + shouldClose = true; + } else if (event instanceof Task task) { + TaskState state = task.status().state(); + if (state.isFinal()) { + shouldClose = true; + } + } + + if (shouldClose && future != null) { + future.cancel(true); // close SSE channel + } } } diff --git a/examples/cloud-deployment/scripts/deploy.sh b/examples/cloud-deployment/scripts/deploy.sh index e267f3302..fff2a6061 100755 --- a/examples/cloud-deployment/scripts/deploy.sh +++ b/examples/cloud-deployment/scripts/deploy.sh @@ -212,6 +212,22 @@ echo "" echo "Deploying PostgreSQL..." kubectl apply -f ../k8s/01-postgres.yaml echo "Waiting for PostgreSQL to be ready..." + +# Wait for pod to be created (StatefulSet takes time to create pod) +for i in {1..30}; do + if kubectl get pod -l app=postgres -n a2a-demo 2>/dev/null | grep -q postgres; then + echo "PostgreSQL pod found, waiting for ready state..." + break + fi + if [ $i -eq 30 ]; then + echo -e "${RED}ERROR: PostgreSQL pod not created after 30 seconds${NC}" + kubectl get statefulset -n a2a-demo + exit 1 + fi + sleep 1 +done + +# Now wait for pod to be ready kubectl wait --for=condition=Ready pod -l app=postgres -n a2a-demo --timeout=120s echo -e "${GREEN}✓ PostgreSQL deployed${NC}" diff --git a/extras/common/src/main/java/io/a2a/extras/common/events/TaskFinalizedEvent.java b/extras/common/src/main/java/io/a2a/extras/common/events/TaskFinalizedEvent.java index 8c5f59348..0c35bad7a 100644 --- a/extras/common/src/main/java/io/a2a/extras/common/events/TaskFinalizedEvent.java +++ b/extras/common/src/main/java/io/a2a/extras/common/events/TaskFinalizedEvent.java @@ -5,22 +5,28 @@ * This event is fired AFTER the database transaction commits, making it safe for downstream * components to assume the task is durably stored. * - *
Used by the replicated queue manager to send poison pill events after ensuring - * the final task state is committed to the database, eliminating race conditions. + *
Used by the replicated queue manager to send the final task state before the poison pill, + * ensuring correct event ordering across instances and eliminating race conditions. */ public class TaskFinalizedEvent { private final String taskId; + private final Object task; // Task type from io.a2a.spec - using Object to avoid dependency - public TaskFinalizedEvent(String taskId) { + public TaskFinalizedEvent(String taskId, Object task) { this.taskId = taskId; + this.task = task; } public String getTaskId() { return taskId; } + public Object getTask() { + return task; + } + @Override public String toString() { - return "TaskFinalizedEvent{taskId='" + taskId + "'}"; + return "TaskFinalizedEvent{taskId='" + taskId + "', task=" + task + "}"; } } diff --git a/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java b/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java index 36245e277..5049bc9a4 100644 --- a/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java +++ b/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java @@ -164,4 +164,5 @@ public void deleteInfo(String taskId, String configId) { taskId, configId); } } + } diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedEventQueueItem.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedEventQueueItem.java index 87c10fb4e..206e07f03 100644 --- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedEventQueueItem.java +++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedEventQueueItem.java @@ -149,6 +149,16 @@ public void setClosedEvent(boolean closedEvent) { } } + /** + * Check if this event is a Task event. + * Task events should always be processed even for inactive tasks, + * as they carry the final task state. 
+ * @return true if this is a Task event + */ + public boolean isTaskEvent() { + return event instanceof io.a2a.spec.Task; + } + @Override public String toString() { return "ReplicatedEventQueueItem{" + diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java index 586ab11a7..44dfbe427 100644 --- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java +++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java @@ -13,8 +13,10 @@ import io.a2a.server.events.EventQueueFactory; import io.a2a.server.events.EventQueueItem; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; import io.a2a.server.events.QueueManager; import io.a2a.server.tasks.TaskStateProvider; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,10 +47,12 @@ protected ReplicatedQueueManager() { } @Inject - public ReplicatedQueueManager(ReplicationStrategy replicationStrategy, TaskStateProvider taskStateProvider) { + public ReplicatedQueueManager(ReplicationStrategy replicationStrategy, + TaskStateProvider taskStateProvider, + MainEventBus mainEventBus) { this.replicationStrategy = replicationStrategy; this.taskStateProvider = taskStateProvider; - this.delegate = new InMemoryQueueManager(new ReplicatingEventQueueFactory(), taskStateProvider); + this.delegate = new InMemoryQueueManager(new ReplicatingEventQueueFactory(), taskStateProvider, mainEventBus); } @@ -77,8 +81,7 @@ public void close(String taskId) { @Override public EventQueue createOrTap(String taskId) { - EventQueue queue = delegate.createOrTap(taskId); - return queue; + return delegate.createOrTap(taskId); } @Override @@ -87,9 +90,11 @@ public void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedExcep } public void onReplicatedEvent(@Observes ReplicatedEventQueueItem replicatedEvent) { - // Check if task is still active before processing replicated event (unless it's a QueueClosedEvent) - // QueueClosedEvent should always be processed to terminate streams, even for inactive tasks + // Check if task is still active before processing replicated event + // Always allow QueueClosedEvent and Task events (they carry final state) + // Skip other event types for inactive tasks to prevent queue creation for expired tasks if (!replicatedEvent.isClosedEvent() + && !replicatedEvent.isTaskEvent() && !taskStateProvider.isTaskActive(replicatedEvent.getTaskId())) { // Task is no longer active - skip processing this replicated event // This prevents creating queues for tasks that have been finalized beyond the grace period @@ -97,38 +102,81 @@ public void onReplicatedEvent(@Observes ReplicatedEventQueueItem replicatedEvent return; } - // Get or create a ChildQueue for this task (creates MainQueue if it doesn't exist) - EventQueue childQueue = delegate.createOrTap(replicatedEvent.getTaskId()); - + // Get the MainQueue to enqueue the replicated event item + // We must use enqueueItem (not enqueueEvent) to preserve the isReplicated() flag + // and avoid triggering the replication hook again (which would cause a replication loop) + // + // IMPORTANT: We must NOT create a ChildQueue here! 
Creating and immediately closing + // a ChildQueue means there are zero children when MainEventBusProcessor distributes + // the event. Existing ChildQueues (from active client subscriptions) will receive + // the event when MainEventBusProcessor distributes it to all children. + // + // If MainQueue doesn't exist, create it. This handles late-arriving replicated events + // for tasks that were created on another instance. + EventQueue childQueue = null; // Track ChildQueue we might create + EventQueue mainQueue = delegate.get(replicatedEvent.getTaskId()); try { - // Get the MainQueue to enqueue the replicated event item - // We must use enqueueItem (not enqueueEvent) to preserve the isReplicated() flag - // and avoid triggering the replication hook again (which would cause a replication loop) - EventQueue mainQueue = delegate.get(replicatedEvent.getTaskId()); + if (mainQueue == null) { + LOGGER.debug("Creating MainQueue for replicated event on task {}", replicatedEvent.getTaskId()); + childQueue = delegate.createOrTap(replicatedEvent.getTaskId()); // Creates MainQueue + returns ChildQueue + mainQueue = delegate.get(replicatedEvent.getTaskId()); // Get MainQueue from map + } + if (mainQueue != null) { mainQueue.enqueueItem(replicatedEvent); } else { - LOGGER.warn("MainQueue not found for task {}, cannot enqueue replicated event. This may happen if the queue was already cleaned up.", - replicatedEvent.getTaskId()); + LOGGER.warn( + "MainQueue not found for task {}, cannot enqueue replicated event. This may happen if the queue was already cleaned up.", + replicatedEvent.getTaskId()); } } finally { - // Close the temporary ChildQueue to prevent leaks - // The MainQueue remains open for other consumers - childQueue.close(); + if (childQueue != null) { + try { + childQueue.close(); // Close the ChildQueue we created (not MainQueue!) + } catch (Exception ignore) { + // The close is safe, but print a stacktrace just in case + if (LOGGER.isDebugEnabled()) { + ignore.printStackTrace(); + } + } + } } } /** * Observes task finalization events fired AFTER database transaction commits. - * This guarantees the task's final state is durably stored before sending the poison pill. + * This guarantees the task's final state is durably stored before replication. * - * @param event the task finalized event containing the task ID + * Sends TaskStatusUpdateEvent (not full Task) FIRST, then the poison pill (QueueClosedEvent), + * ensuring correct event ordering across instances and eliminating race conditions where + * the poison pill arrives before the final task state. + * + * IMPORTANT: We send TaskStatusUpdateEvent instead of full Task to maintain consistency + * with local event distribution. Clients expect TaskStatusUpdateEvent for status changes, + * and sending the full Task causes issues in remote instances where clients don't handle + * bare Task objects the same way they handle TaskStatusUpdateEvent. 
+ * + * @param event the task finalized event containing the task ID and final Task */ public void onTaskFinalized(@Observes(during = TransactionPhase.AFTER_SUCCESS) TaskFinalizedEvent event) { String taskId = event.getTaskId(); - LOGGER.debug("Task {} finalized - sending poison pill (QueueClosedEvent) after transaction commit", taskId); + io.a2a.spec.Task finalTask = (io.a2a.spec.Task) event.getTask(); // Cast from Object + + LOGGER.debug("Task {} finalized - sending TaskStatusUpdateEvent then poison pill (QueueClosedEvent) after transaction commit", taskId); + + // Convert final Task to TaskStatusUpdateEvent to match local event distribution + // This ensures remote instances receive the same event type as local instances + io.a2a.spec.TaskStatusUpdateEvent finalStatusEvent = io.a2a.spec.TaskStatusUpdateEvent.builder() + .taskId(taskId) + .contextId(finalTask.contextId()) + .status(finalTask.status()) + .isFinal(true) + .build(); + + // Send TaskStatusUpdateEvent FIRST to ensure it arrives before poison pill + replicationStrategy.send(taskId, finalStatusEvent); - // Send poison pill directly via replication strategy + // Then send poison pill // The transaction has committed, so the final state is guaranteed to be in the database io.a2a.server.events.QueueClosedEvent closedEvent = new io.a2a.server.events.QueueClosedEvent(taskId); replicationStrategy.send(taskId, closedEvent); @@ -152,12 +200,11 @@ public EventQueue.EventQueueBuilder builder(String taskId) { // which sends the QueueClosedEvent after the database transaction commits. // This ensures proper ordering and transactional guarantees. - // Return the builder with callbacks - return delegate.getEventQueueBuilder(taskId) - .taskId(taskId) - .hook(new ReplicationHook(taskId)) - .addOnCloseCallback(delegate.getCleanupCallback(taskId)) - .taskStateProvider(taskStateProvider); + // Call createBaseEventQueueBuilder() directly to avoid infinite recursion + // (getEventQueueBuilder() would delegate back to this factory, creating a loop) + // The base builder already includes: taskId, cleanup callback, taskStateProvider, mainEventBus + return delegate.createBaseEventQueueBuilder(taskId) + .hook(new ReplicationHook(taskId)); } } diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java index 43571cd30..14b4c1f51 100644 --- a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java +++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java @@ -22,12 +22,19 @@ import io.a2a.server.events.EventQueueClosedException; import io.a2a.server.events.EventQueueItem; import io.a2a.server.events.EventQueueTestHelper; +import io.a2a.server.events.EventQueueUtil; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.server.events.QueueClosedEvent; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.Event; import io.a2a.spec.StreamingEventKind; +import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ 
-35,10 +42,27 @@ class ReplicatedQueueManagerTest { private ReplicatedQueueManager queueManager; private StreamingEventKind testEvent; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach void setUp() { - queueManager = new ReplicatedQueueManager(new NoOpReplicationStrategy(), new MockTaskStateProvider(true)); + // Create MainEventBus first + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + + // Create QueueManager before MainEventBusProcessor (processor needs it as parameter) + queueManager = new ReplicatedQueueManager( + new NoOpReplicationStrategy(), + new MockTaskStateProvider(true), + mainEventBus + ); + + // Create MainEventBusProcessor with QueueManager + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + testEvent = TaskStatusUpdateEvent.builder() .taskId("test-task") .contextId("test-context") @@ -47,25 +71,82 @@ void setUp() { .build(); } + /** + * Helper to create a test event with the specified taskId. + * This ensures taskId consistency between queue creation and event creation. + */ + private TaskStatusUpdateEvent createEventForTask(String taskId) { + return TaskStatusUpdateEvent.builder() + .taskId(taskId) + .contextId("test-context") + .status(new TaskStatus(TaskState.SUBMITTED)) + .isFinal(false) + .build(); + } + + @AfterEach + void tearDown() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + mainEventBusProcessor = null; + mainEventBus = null; + queueManager = null; + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. 
+ * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, io.a2a.spec.Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + } + @Test void testReplicationStrategyTriggeredOnNormalEnqueue() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "test-task-1"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - queue.enqueueEvent(testEvent); + // Wait for MainEventBusProcessor to process the event and trigger replication + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(1, strategy.getCallCount()); assertEquals(taskId, strategy.getLastTaskId()); - assertEquals(testEvent, strategy.getLastEvent()); + assertEquals(event, strategy.getLastEvent()); } @Test void testReplicationStrategyNotTriggeredOnReplicatedEvent() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "test-task-2"; EventQueue queue = queueManager.createOrTap(taskId); @@ -79,17 +160,19 @@ void testReplicationStrategyNotTriggeredOnReplicatedEvent() throws InterruptedEx @Test void testReplicationStrategyWithCountingImplementation() throws InterruptedException { CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "test-task-3"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - queue.enqueueEvent(testEvent); - queue.enqueueEvent(testEvent); + // Wait for MainEventBusProcessor to process each event + waitForEventProcessing(() -> queue.enqueueEvent(event)); + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(2, countingStrategy.getCallCount()); assertEquals(taskId, countingStrategy.getLastTaskId()); - assertEquals(testEvent, countingStrategy.getLastEvent()); + assertEquals(event, countingStrategy.getLastEvent()); ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); queueManager.onReplicatedEvent(replicatedEvent); @@ -100,46 +183,45 @@ void testReplicationStrategyWithCountingImplementation() throws InterruptedExcep @Test 
void testReplicatedEventDeliveredToCorrectQueue() throws InterruptedException { String taskId = "test-task-4"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId EventQueue queue = queueManager.createOrTap(taskId); - ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); - queueManager.onReplicatedEvent(replicatedEvent); + ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask); - Event dequeuedEvent; - try { - dequeuedEvent = queue.dequeueEventItem(100).getEvent(); - } catch (EventQueueClosedException e) { - fail("Queue should not be closed"); - return; - } - assertEquals(testEvent, dequeuedEvent); + // Use callback to wait for event processing + EventQueueItem item = dequeueEventWithRetry(queue, () -> queueManager.onReplicatedEvent(replicatedEvent)); + assertNotNull(item, "Event should be available in queue"); + Event dequeuedEvent = item.getEvent(); + assertEquals(eventForTask, dequeuedEvent); } @Test void testReplicatedEventCreatesQueueIfNeeded() throws InterruptedException { String taskId = "non-existent-task"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId // Verify no queue exists initially assertNull(queueManager.get(taskId)); - ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); - - // Process the replicated event - assertDoesNotThrow(() -> queueManager.onReplicatedEvent(replicatedEvent)); - - // Verify that a queue was created and the event was enqueued - EventQueue queue = queueManager.get(taskId); - assertNotNull(queue, "Queue should be created when processing replicated event for non-existent task"); - - // Verify the event was enqueued by dequeuing it - Event dequeuedEvent; - try { - dequeuedEvent = queue.dequeueEventItem(100).getEvent(); - } catch (EventQueueClosedException e) { - fail("Queue should not be closed"); - return; - } - assertEquals(testEvent, dequeuedEvent, "The replicated event should be enqueued in the newly created queue"); + // Create a ChildQueue BEFORE processing the replicated event + // This ensures the ChildQueue exists when MainEventBusProcessor distributes the event + EventQueue childQueue = queueManager.createOrTap(taskId); + assertNotNull(childQueue, "ChildQueue should be created"); + + // Verify MainQueue was created + EventQueue mainQueue = queueManager.get(taskId); + assertNotNull(mainQueue, "MainQueue should exist after createOrTap"); + + ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask); + + // Process the replicated event and wait for distribution + // Use callback to wait for event processing + EventQueueItem item = dequeueEventWithRetry(childQueue, () -> { + assertDoesNotThrow(() -> queueManager.onReplicatedEvent(replicatedEvent)); + }); + assertNotNull(item, "Event should be available in queue"); + Event dequeuedEvent = item.getEvent(); + assertEquals(eventForTask, dequeuedEvent, "The replicated event should be enqueued in the newly created queue"); } @Test @@ -170,17 +252,18 @@ void testBasicQueueManagerFunctionality() throws InterruptedException { void testQueueToTaskIdMappingMaintained() throws InterruptedException { String taskId = "test-task-6"; CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(countingStrategy, new 
MockTaskStateProvider(true), mainEventBus); + TaskStatusUpdateEvent event = createEventForTask(taskId); EventQueue queue = queueManager.createOrTap(taskId); - queue.enqueueEvent(testEvent); + waitForEventProcessing(() -> queue.enqueueEvent(event)); assertEquals(taskId, countingStrategy.getLastTaskId()); queueManager.close(taskId); // Task is active, so NO poison pill is sent EventQueue newQueue = queueManager.createOrTap(taskId); - newQueue.enqueueEvent(testEvent); + waitForEventProcessing(() -> newQueue.enqueueEvent(event)); assertEquals(taskId, countingStrategy.getLastTaskId()); // 2 replication calls: 1 testEvent, 1 testEvent (no QueueClosedEvent because task is active) @@ -217,16 +300,43 @@ void testReplicatedEventJsonSerialization() throws Exception { @Test void testParallelReplicationBehavior() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "parallel-test-task"; EventQueue queue = queueManager.createOrTap(taskId); int numThreads = 10; int eventsPerThread = 5; + int expectedEventCount = (numThreads / 2) * eventsPerThread; // Only normal enqueues + int totalEventCount = numThreads * eventsPerThread; // All events (normal + replicated) ExecutorService executor = Executors.newFixedThreadPool(numThreads); CountDownLatch startLatch = new CountDownLatch(1); CountDownLatch doneLatch = new CountDownLatch(numThreads); + + // Use CyclicBarrier for better thread synchronization + // This ensures all threads start their work at approximately the same time + java.util.concurrent.CyclicBarrier barrier = new java.util.concurrent.CyclicBarrier(numThreads); + + // Track processed events for better diagnostics on failure + java.util.concurrent.CopyOnWriteArrayList processedEvents = + new java.util.concurrent.CopyOnWriteArrayList<>(); + + // Set up callback to wait for ALL events to be processed by MainEventBusProcessor + // Must wait for all 50 events (25 normal + 25 replicated) to ensure all normal events + // have triggered replication before we check the count + CountDownLatch processingLatch = new CountDownLatch(totalEventCount); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String tid, io.a2a.spec.Event event) { + processedEvents.add(event); + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String tid) { + // Not needed for this test + } + }); // Launch threads that will enqueue events normally (should trigger replication) for (int i = 0; i < numThreads / 2; i++) { @@ -234,18 +344,20 @@ void testParallelReplicationBehavior() throws InterruptedException { executor.submit(() -> { try { startLatch.await(); + barrier.await(); // Synchronize thread starts for better interleaving for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() - .taskId("normal-" + threadId + "-" + j) + .taskId(taskId) // Use same taskId as queue .contextId("test-context") .status(new TaskStatus(TaskState.WORKING)) .isFinal(false) .build(); queue.enqueueEvent(event); - Thread.sleep(1); // Small delay to interleave operations } } catch (InterruptedException e) { Thread.currentThread().interrupt(); + } catch (java.util.concurrent.BrokenBarrierException e) { + throw new RuntimeException("Barrier broken", 
e); } finally { doneLatch.countDown(); } @@ -258,19 +370,21 @@ void testParallelReplicationBehavior() throws InterruptedException { executor.submit(() -> { try { startLatch.await(); + barrier.await(); // Synchronize thread starts for better interleaving for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() - .taskId("replicated-" + threadId + "-" + j) + .taskId(taskId) // Use same taskId as queue .contextId("test-context") .status(new TaskStatus(TaskState.COMPLETED)) .isFinal(true) .build(); ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, event); queueManager.onReplicatedEvent(replicatedEvent); - Thread.sleep(1); // Small delay to interleave operations } } catch (InterruptedException e) { Thread.currentThread().interrupt(); + } catch (java.util.concurrent.BrokenBarrierException e) { + throw new RuntimeException("Barrier broken", e); } finally { doneLatch.countDown(); } @@ -280,24 +394,43 @@ void testParallelReplicationBehavior() throws InterruptedException { // Start all threads simultaneously startLatch.countDown(); - // Wait for all threads to complete - assertTrue(doneLatch.await(10, TimeUnit.SECONDS), "All threads should complete within 10 seconds"); + // Wait for all threads to complete with explicit timeout + assertTrue(doneLatch.await(10, TimeUnit.SECONDS), + "All " + numThreads + " threads should complete within 10 seconds"); executor.shutdown(); - assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS), "Executor should shutdown within 5 seconds"); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS), + "Executor should shutdown within 5 seconds"); + + // Wait for MainEventBusProcessor to process all events + try { + boolean allProcessed = processingLatch.await(10, TimeUnit.SECONDS); + assertTrue(allProcessed, + String.format("MainEventBusProcessor should have processed all %d events within timeout. " + + "Processed: %d, Remaining: %d", + totalEventCount, processedEvents.size(), processingLatch.getCount())); + } finally { + mainEventBusProcessor.setCallback(null); + queue.close(true, true); + } + + // Verify we processed the expected number of events + assertEquals(totalEventCount, processedEvents.size(), + "Should have processed exactly " + totalEventCount + " events (normal + replicated)"); // Only the normal enqueue operations should have triggered replication // numThreads/2 threads * eventsPerThread events each = total expected replication calls int expectedReplicationCalls = (numThreads / 2) * eventsPerThread; assertEquals(expectedReplicationCalls, strategy.getCallCount(), - "Only normal enqueue operations should trigger replication, not replicated events"); + String.format("Only normal enqueue operations should trigger replication, not replicated events. 
" + + "Expected: %d, Actual: %d", expectedReplicationCalls, strategy.getCallCount())); } @Test void testReplicatedEventSkippedWhenTaskInactive() throws InterruptedException { // Create a task state provider that returns false (task is inactive) MockTaskStateProvider stateProvider = new MockTaskStateProvider(false); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "inactive-task"; @@ -316,30 +449,32 @@ void testReplicatedEventSkippedWhenTaskInactive() throws InterruptedException { void testReplicatedEventProcessedWhenTaskActive() throws InterruptedException { // Create a task state provider that returns true (task is active) MockTaskStateProvider stateProvider = new MockTaskStateProvider(true); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "active-task"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId // Verify no queue exists initially assertNull(queueManager.get(taskId)); - // Process a replicated event for an active task - ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); - queueManager.onReplicatedEvent(replicatedEvent); + // Create a ChildQueue BEFORE processing the replicated event + // This ensures the ChildQueue exists when MainEventBusProcessor distributes the event + EventQueue childQueue = queueManager.createOrTap(taskId); + assertNotNull(childQueue, "ChildQueue should be created"); - // Queue should be created and event should be enqueued - EventQueue queue = queueManager.get(taskId); - assertNotNull(queue, "Queue should be created for active task"); + // Verify MainQueue was created + EventQueue mainQueue = queueManager.get(taskId); + assertNotNull(mainQueue, "MainQueue should exist after createOrTap"); - // Verify the event was enqueued - Event dequeuedEvent; - try { - dequeuedEvent = queue.dequeueEventItem(100).getEvent(); - } catch (EventQueueClosedException e) { - fail("Queue should not be closed"); - return; - } - assertEquals(testEvent, dequeuedEvent, "Event should be enqueued for active task"); + // Process a replicated event for an active task + ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask); + + // Verify the event was enqueued and distributed to our ChildQueue + // Use callback to wait for event processing + EventQueueItem item = dequeueEventWithRetry(childQueue, () -> queueManager.onReplicatedEvent(replicatedEvent)); + assertNotNull(item, "Event should be available in queue"); + Event dequeuedEvent = item.getEvent(); + assertEquals(eventForTask, dequeuedEvent, "Event should be enqueued for active task"); } @@ -347,7 +482,7 @@ void testReplicatedEventProcessedWhenTaskActive() throws InterruptedException { void testReplicatedEventToExistingQueueWhenTaskBecomesInactive() throws InterruptedException { // Create a task state provider that returns true initially MockTaskStateProvider stateProvider = new MockTaskStateProvider(true); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "task-becomes-inactive"; @@ -387,30 +522,38 @@ void 
testReplicatedEventToExistingQueueWhenTaskBecomesInactive() throws Interrup @Test void testPoisonPillSentViaTransactionAwareEvent() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "poison-pill-test"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - // Enqueue a normal event first - queue.enqueueEvent(testEvent); + // Enqueue a normal event first and wait for processing + waitForEventProcessing(() -> queue.enqueueEvent(event)); // In the new architecture, QueueClosedEvent (poison pill) is sent via CDI events // when JpaDatabaseTaskStore.save() persists a final task and the transaction commits // ReplicatedQueueManager.onTaskFinalized() observes AFTER_SUCCESS and sends the poison pill // Simulate the CDI event observer being called (what happens in real execution) - TaskFinalizedEvent taskFinalizedEvent = new TaskFinalizedEvent(taskId); + // Create a final task for the event + Task finalTask = Task.builder() + .id(taskId) + .contextId("test-context") + .status(new TaskStatus(TaskState.COMPLETED)) + .build(); + TaskFinalizedEvent taskFinalizedEvent = new TaskFinalizedEvent(taskId, finalTask); // Call the observer method directly (simulating CDI event delivery) queueManager.onTaskFinalized(taskFinalizedEvent); - // Verify that QueueClosedEvent was replicated - // strategy.getCallCount() should be 2: one for testEvent, one for QueueClosedEvent - assertEquals(2, strategy.getCallCount(), "Should have replicated both normal event and QueueClosedEvent"); + // Verify that final Task and QueueClosedEvent were replicated + // strategy.getCallCount() should be 3: testEvent, final Task, then QueueClosedEvent (poison pill) + assertEquals(3, strategy.getCallCount(), "Should have replicated testEvent, final Task, and QueueClosedEvent"); + // Verify the last event is QueueClosedEvent (poison pill) Event lastEvent = strategy.getLastEvent(); - assertTrue(lastEvent instanceof QueueClosedEvent, "Last replicated event should be QueueClosedEvent"); + assertTrue(lastEvent instanceof QueueClosedEvent, "Last replicated event should be QueueClosedEvent (poison pill)"); assertEquals(taskId, ((QueueClosedEvent) lastEvent).getTaskId()); } @@ -451,36 +594,21 @@ void testQueueClosedEventJsonSerialization() throws Exception { @Test void testReplicatedQueueClosedEventTerminatesConsumer() throws InterruptedException { String taskId = "remote-close-test"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId EventQueue queue = queueManager.createOrTap(taskId); - // Enqueue a normal event - queue.enqueueEvent(testEvent); - // Simulate receiving QueueClosedEvent from remote node QueueClosedEvent closedEvent = new QueueClosedEvent(taskId); ReplicatedEventQueueItem replicatedClosedEvent = new ReplicatedEventQueueItem(taskId, closedEvent); - queueManager.onReplicatedEvent(replicatedClosedEvent); - // Dequeue the normal event first - EventQueueItem item1; - try { - item1 = queue.dequeueEventItem(100); - } catch (EventQueueClosedException e) { - fail("Should not throw on first dequeue"); - return; - } - assertNotNull(item1); - assertEquals(testEvent, item1.getEvent()); + // Dequeue the normal event first (use callback to wait for async processing) + EventQueueItem item1 = 
dequeueEventWithRetry(queue, () -> queue.enqueueEvent(eventForTask)); + assertNotNull(item1, "First event should be available"); + assertEquals(eventForTask, item1.getEvent()); - // Next dequeue should get the QueueClosedEvent - EventQueueItem item2; - try { - item2 = queue.dequeueEventItem(100); - } catch (EventQueueClosedException e) { - fail("Should not throw on second dequeue, should return the event"); - return; - } - assertNotNull(item2); + // Next dequeue should get the QueueClosedEvent (use callback to wait for async processing) + EventQueueItem item2 = dequeueEventWithRetry(queue, () -> queueManager.onReplicatedEvent(replicatedClosedEvent)); + assertNotNull(item2, "QueueClosedEvent should be available"); assertTrue(item2.getEvent() instanceof QueueClosedEvent, "Second event should be QueueClosedEvent"); } @@ -539,4 +667,25 @@ public void setActive(boolean active) { this.active = active; } } + + /** + * Helper method to dequeue an event after waiting for MainEventBusProcessor distribution. + * Uses callback-based waiting instead of polling for deterministic synchronization. + * + * @param queue the queue to dequeue from + * @param enqueueAction the action that enqueues the event (triggers event processing) + * @return the dequeued EventQueueItem, or null if queue is closed + */ + private EventQueueItem dequeueEventWithRetry(EventQueue queue, Runnable enqueueAction) throws InterruptedException { + // Wait for event to be processed and distributed + waitForEventProcessing(enqueueAction); + + // Event is now available, dequeue directly + try { + return queue.dequeueEventItem(100); + } catch (EventQueueClosedException e) { + // Queue closed, return null + return null; + } + } } \ No newline at end of file diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java new file mode 100644 index 000000000..a91575aaa --- /dev/null +++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java @@ -0,0 +1,11 @@ +package io.a2a.server.events; + +public class EventQueueUtil { + public static void start(MainEventBusProcessor processor) { + processor.start(); + } + + public static void stop(MainEventBusProcessor processor) { + processor.stop(); + } +} diff --git a/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-1/src/main/resources/application.properties b/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-1/src/main/resources/application.properties index d0692ca53..ea64096cb 100644 --- a/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-1/src/main/resources/application.properties +++ b/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-1/src/main/resources/application.properties @@ -34,5 +34,7 @@ quarkus.messaging.kafka.health.timeout=5s # Enable debug logging quarkus.log.category."io.a2a.server.events".level=DEBUG quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG +quarkus.log.category."io.a2a.server.tasks".level=DEBUG quarkus.log.category."io.a2a.extras.queuemanager.replicated".level=DEBUG +quarkus.log.category."io.a2a.extras.taskstore.database.jpa".level=DEBUG quarkus.log.category."io.a2a.client".level=DEBUG diff --git a/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-2/src/main/resources/application.properties b/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-2/src/main/resources/application.properties index 
0b647f3a5..d9a495e8f 100644 --- a/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-2/src/main/resources/application.properties +++ b/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-2/src/main/resources/application.properties @@ -34,5 +34,7 @@ quarkus.messaging.kafka.health.timeout=5s # Enable debug logging quarkus.log.category."io.a2a.server.events".level=DEBUG quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG +quarkus.log.category."io.a2a.server.tasks".level=DEBUG quarkus.log.category."io.a2a.extras.queuemanager.replicated".level=DEBUG +quarkus.log.category."io.a2a.extras.taskstore.database.jpa".level=DEBUG quarkus.log.category."io.a2a.client".level=DEBUG diff --git a/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/io/a2a/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java b/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/io/a2a/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java index 93093388d..b98f0e87d 100644 --- a/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/io/a2a/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java +++ b/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/io/a2a/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java @@ -10,6 +10,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; @@ -254,9 +255,11 @@ public void testMultiInstanceEventReplication() throws Exception { final String taskId = "replication-test-task-" + System.currentTimeMillis(); final String contextId = "replication-test-context"; - // Step 1: Send initial message NON-streaming to create task - Message initialMessage = Message.builder(A2A.toUserMessage("Initial test message")) - .taskId(taskId) + Throwable testFailure = null; + try { + // Step 1: Send initial message NON-streaming to create task + Message initialMessage = Message.builder(A2A.toUserMessage("Initial test message")) + .taskId(taskId) .contextId(contextId) .build(); @@ -308,11 +311,33 @@ public void testMultiInstanceEventReplication() throws Exception { AtomicReference app1Error = new AtomicReference<>(); AtomicReference app2Error = new AtomicReference<>(); + AtomicBoolean app1ReceivedInitialTask = new AtomicBoolean(false); + AtomicBoolean app2ReceivedInitialTask = new AtomicBoolean(false); // App1 subscriber BiConsumer app1Subscriber = (event, card) -> { + String eventDetail = event.getClass().getSimpleName(); + if (event instanceof io.a2a.client.TaskUpdateEvent tue) { + eventDetail += " [" + tue.getUpdateEvent().getClass().getSimpleName(); + if (tue.getUpdateEvent() instanceof io.a2a.spec.TaskStatusUpdateEvent statusEvent) { + eventDetail += ", state=" + (statusEvent.status() != null ? statusEvent.status().state() : "null"); + } + eventDetail += "]"; + } else if (event instanceof io.a2a.client.TaskEvent te) { + eventDetail += " [state=" + (te.getTask().status() != null ? 
te.getTask().status().state() : "null") + "]"; + } + System.out.println("APP1 received event: " + eventDetail); + + // Per A2A spec 3.1.6: Handle initial TaskEvent on resubscribe + if (!app1ReceivedInitialTask.get() && event instanceof io.a2a.client.TaskEvent) { + app1ReceivedInitialTask.set(true); + System.out.println("APP1 filtered initial TaskEvent"); + // Don't count initial TaskEvent toward expected artifact/status events + return; + } app1Events.add(event); - app1EventCount.incrementAndGet(); + int count = app1EventCount.incrementAndGet(); + System.out.println("APP1 event count now: " + count + ", event: " + eventDetail); }; Consumer app1ErrorHandler = error -> { @@ -323,8 +348,28 @@ public void testMultiInstanceEventReplication() throws Exception { // App2 subscriber BiConsumer app2Subscriber = (event, card) -> { + String eventDetail = event.getClass().getSimpleName(); + if (event instanceof io.a2a.client.TaskUpdateEvent tue) { + eventDetail += " [" + tue.getUpdateEvent().getClass().getSimpleName(); + if (tue.getUpdateEvent() instanceof io.a2a.spec.TaskStatusUpdateEvent statusEvent) { + eventDetail += ", state=" + (statusEvent.status() != null ? statusEvent.status().state() : "null"); + } + eventDetail += "]"; + } else if (event instanceof io.a2a.client.TaskEvent te) { + eventDetail += " [state=" + (te.getTask().status() != null ? te.getTask().status().state() : "null") + "]"; + } + System.out.println("APP2 received event: " + eventDetail); + + // Per A2A spec 3.1.6: Handle initial TaskEvent on resubscribe + if (!app2ReceivedInitialTask.get() && event instanceof io.a2a.client.TaskEvent) { + app2ReceivedInitialTask.set(true); + System.out.println("APP2 filtered initial TaskEvent"); + // Don't count initial TaskEvent toward expected artifact/status events + return; + } app2Events.add(event); - app2EventCount.incrementAndGet(); + int count = app2EventCount.incrementAndGet(); + System.out.println("APP2 event count now: " + count + ", event: " + eventDetail); }; Consumer app2ErrorHandler = error -> { @@ -409,9 +454,45 @@ public void testMultiInstanceEventReplication() throws Exception { throw new AssertionError("App2 subscriber error", app2Error.get()); } - // Verify both received at least 3 events (could be more due to initial state events) - assertTrue(app1Events.size() >= 3, "App1 should receive at least 3 events, got: " + app1Events.size()); - assertTrue(app2Events.size() >= 3, "App2 should receive at least 3 events, got: " + app2Events.size()); + // Verify both received at least 3 events (could be more due to initial state events) + assertTrue(app1Events.size() >= 3, "App1 should receive at least 3 events, got: " + app1Events.size()); + assertTrue(app2Events.size() >= 3, "App2 should receive at least 3 events, got: " + app2Events.size()); + } catch (Throwable t) { + testFailure = t; + throw t; + } finally { + // Output container logs if test failed + if (testFailure != null) { + System.err.println("\n========================================"); + System.err.println("TEST FAILED - Dumping container logs"); + System.err.println("========================================\n"); + + dumpContainerLogs("KAFKA", kafka, 100); + dumpContainerLogs("APP1", app1, 200); + dumpContainerLogs("APP2", app2, 200); + + System.err.println("\n========================================"); + System.err.println("END OF CONTAINER LOGS"); + System.err.println("========================================\n"); + } + } + } + + /** + * Dumps the last N lines of logs from a container to stderr. 
+ */ + private void dumpContainerLogs(String containerName, org.testcontainers.containers.ContainerState container, int lastLines) { + System.err.println("\n--- " + containerName + " LOGS (last " + lastLines + " lines) ---"); + try { + String logs = container.getLogs(); + String[] lines = logs.split("\n"); + int start = Math.max(0, lines.length - lastLines); + for (int i = start; i < lines.length; i++) { + System.err.println(lines[i]); + } + } catch (Exception e) { + System.err.println("Failed to retrieve " + containerName + " logs: " + e.getMessage()); + } } /** diff --git a/extras/queue-manager-replicated/tests-single-instance/src/test/java/io/a2a/extras/queuemanager/replicated/tests/KafkaReplicationIntegrationTest.java b/extras/queue-manager-replicated/tests-single-instance/src/test/java/io/a2a/extras/queuemanager/replicated/tests/KafkaReplicationIntegrationTest.java index 3825f3bad..c4b23fb0d 100644 --- a/extras/queue-manager-replicated/tests-single-instance/src/test/java/io/a2a/extras/queuemanager/replicated/tests/KafkaReplicationIntegrationTest.java +++ b/extras/queue-manager-replicated/tests-single-instance/src/test/java/io/a2a/extras/queuemanager/replicated/tests/KafkaReplicationIntegrationTest.java @@ -222,9 +222,21 @@ public void testKafkaEventReceivedByA2AServer() throws Exception { AtomicReference receivedCompletedEvent = new AtomicReference<>(); AtomicBoolean wasUnexpectedEvent = new AtomicBoolean(false); AtomicReference errorRef = new AtomicReference<>(); + AtomicBoolean receivedInitialTask = new AtomicBoolean(false); // Create consumer to handle resubscribed events BiConsumer consumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!receivedInitialTask.get()) { + if (event instanceof TaskEvent) { + receivedInitialTask.set(true); + return; + } else { + throw new AssertionError("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent taskUpdateEvent) { if (taskUpdateEvent.getUpdateEvent() instanceof TaskStatusUpdateEvent statusEvent) { if (statusEvent.status().state() == TaskState.COMPLETED) { diff --git a/extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java b/extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java index b65b71650..a4afa8588 100644 --- a/extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java +++ b/extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java @@ -65,19 +65,23 @@ void initConfig() { @Transactional @Override - public void save(Task task) { - LOGGER.debug("Saving task with ID: {}", task.id()); + public void save(Task task, boolean isReplicated) { + LOGGER.debug("Saving task with ID: {} (replicated: {})", task.id(), isReplicated); try { JpaTask jpaTask = JpaTask.createFromTask(task); em.merge(jpaTask); LOGGER.debug("Persisted/updated task with ID: {}", task.id()); - if (task.status() != null && task.status().state() != null && task.status().state().isFinal()) { + // Only fire TaskFinalizedEvent for locally-generated final states, NOT for replicated events + // This prevents feedback loops where receiving a replicated final task triggers another replication + if (!isReplicated && task.status() != null && task.status().state() != null && task.status().state().isFinal()) { // Fire CDI 
event if task reached final state // IMPORTANT: The event will be delivered AFTER transaction commits (AFTER_SUCCESS observers) - // This ensures the task's final state is durably stored before the QueueClosedEvent poison pill is sent - LOGGER.debug("Task {} is in final state, firing TaskFinalizedEvent", task.id()); - taskFinalizedEvent.fire(new TaskFinalizedEvent(task.id())); + // This ensures the task's final state is durably stored before the final task and poison pill are sent + LOGGER.debug("Task {} is in final state, firing TaskFinalizedEvent with full Task", task.id()); + taskFinalizedEvent.fire(new TaskFinalizedEvent(task.id(), task)); + } else if (isReplicated && task.status() != null && task.status().state() != null && task.status().state().isFinal()) { + LOGGER.debug("Task {} is in final state but from replication - NOT firing TaskFinalizedEvent (prevents feedback loop)", task.id()); } } catch (JsonProcessingException e) { LOGGER.error("Failed to serialize task with ID: {}", task.id(), e); diff --git a/extras/task-store-database-jpa/src/test/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java b/extras/task-store-database-jpa/src/test/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java index ea77f73c7..15c01a626 100644 --- a/extras/task-store-database-jpa/src/test/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java +++ b/extras/task-store-database-jpa/src/test/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java @@ -54,7 +54,7 @@ public void testSaveAndRetrieveTask() { .build(); // Save the task - taskStore.save(task); + taskStore.save(task, false); // Retrieve the task Task retrieved = taskStore.get("test-task-1"); @@ -84,7 +84,7 @@ public void testSaveAndRetrieveTaskWithHistory() { .build(); // Save the task - taskStore.save(task); + taskStore.save(task, false); // Retrieve the task Task retrieved = taskStore.get("test-task-2"); @@ -108,7 +108,7 @@ public void testUpdateExistingTask() { .status(new TaskStatus(TaskState.SUBMITTED)) .build(); - taskStore.save(initialTask); + taskStore.save(initialTask, false); // Update the task Task updatedTask = Task.builder() @@ -117,7 +117,7 @@ public void testUpdateExistingTask() { .status(new TaskStatus(TaskState.COMPLETED)) .build(); - taskStore.save(updatedTask); + taskStore.save(updatedTask, false); // Retrieve and verify the update Task retrieved = taskStore.get("test-task-3"); @@ -144,7 +144,7 @@ public void testDeleteTask() { .status(new TaskStatus(TaskState.SUBMITTED)) .build(); - taskStore.save(task); + taskStore.save(task, false); // Verify it exists assertNotNull(taskStore.get("test-task-4")); @@ -180,7 +180,7 @@ public void testTaskWithComplexMetadata() { .build(); // Save and retrieve - taskStore.save(task); + taskStore.save(task, false); Task retrieved = taskStore.get("test-task-5"); assertNotNull(retrieved); @@ -201,7 +201,7 @@ public void testIsTaskActiveForNonFinalTask() { .status(new TaskStatus(TaskState.WORKING)) .build(); - taskStore.save(task); + taskStore.save(task, false); // Task should be active (not in final state) JpaDatabaseTaskStore jpaDatabaseTaskStore = (JpaDatabaseTaskStore) taskStore; @@ -220,7 +220,7 @@ public void testIsTaskActiveForFinalTaskWithinGracePeriod() { .status(new TaskStatus(TaskState.WORKING)) .build(); - taskStore.save(task); + taskStore.save(task, false); // Update to final state Task finalTask = Task.builder() @@ -229,7 +229,7 @@ public void testIsTaskActiveForFinalTaskWithinGracePeriod() { .status(new 
TaskStatus(TaskState.COMPLETED)) .build(); - taskStore.save(finalTask); + taskStore.save(finalTask, false); // Task should be active (within grace period - default 15 seconds) JpaDatabaseTaskStore jpaDatabaseTaskStore = (JpaDatabaseTaskStore) taskStore; @@ -248,7 +248,7 @@ public void testIsTaskActiveForFinalTaskBeyondGracePeriod() { .status(new TaskStatus(TaskState.COMPLETED)) .build(); - taskStore.save(task); + taskStore.save(task, false); // Directly update the finalizedAt timestamp to 20 seconds in the past // (beyond the default 15-second grace period) @@ -322,9 +322,9 @@ public void testListTasksFilterByContextId() { .status(new TaskStatus(TaskState.COMPLETED)) .build(); - taskStore.save(task1); - taskStore.save(task2); - taskStore.save(task3); + taskStore.save(task1, false); + taskStore.save(task2, false); + taskStore.save(task3, false); // List tasks for context-A ListTasksParams params = ListTasksParams.builder() @@ -361,9 +361,9 @@ public void testListTasksFilterByStatus() { .status(new TaskStatus(TaskState.COMPLETED)) .build(); - taskStore.save(task1); - taskStore.save(task2); - taskStore.save(task3); + taskStore.save(task1, false); + taskStore.save(task2, false); + taskStore.save(task3, false); // List only WORKING tasks in this context ListTasksParams params = ListTasksParams.builder() @@ -401,9 +401,9 @@ public void testListTasksCombinedFilters() { .status(new TaskStatus(TaskState.WORKING)) .build(); - taskStore.save(task1); - taskStore.save(task2); - taskStore.save(task3); + taskStore.save(task1, false); + taskStore.save(task2, false); + taskStore.save(task3, false); // List WORKING tasks in context-X ListTasksParams params = ListTasksParams.builder() @@ -432,7 +432,7 @@ public void testListTasksPagination() { .contextId("context-pagination") .status(new TaskStatus(TaskState.SUBMITTED, null, sameTimestamp)) .build(); - taskStore.save(task); + taskStore.save(task, false); } // First page: pageSize=2 @@ -488,7 +488,7 @@ public void testListTasksPaginationWithDifferentTimestamps() { .contextId("context-diff-timestamps") .status(new TaskStatus(TaskState.WORKING, null, now.minusMinutes(10))) .build(); - taskStore.save(task1); + taskStore.save(task1, false); // Task 2: 5 minutes ago, ID="task-diff-b" Task task2 = Task.builder() @@ -496,7 +496,7 @@ public void testListTasksPaginationWithDifferentTimestamps() { .contextId("context-diff-timestamps") .status(new TaskStatus(TaskState.WORKING, null, now.minusMinutes(5))) .build(); - taskStore.save(task2); + taskStore.save(task2, false); // Task 3: 5 minutes ago, ID="task-diff-c" (same timestamp as task2, tests ID tie-breaker) Task task3 = Task.builder() @@ -504,7 +504,7 @@ public void testListTasksPaginationWithDifferentTimestamps() { .contextId("context-diff-timestamps") .status(new TaskStatus(TaskState.WORKING, null, now.minusMinutes(5))) .build(); - taskStore.save(task3); + taskStore.save(task3, false); // Task 4: Now, ID="task-diff-d" Task task4 = Task.builder() @@ -512,7 +512,7 @@ public void testListTasksPaginationWithDifferentTimestamps() { .contextId("context-diff-timestamps") .status(new TaskStatus(TaskState.WORKING, null, now)) .build(); - taskStore.save(task4); + taskStore.save(task4, false); // Task 5: 1 minute ago, ID="task-diff-e" Task task5 = Task.builder() @@ -520,7 +520,7 @@ public void testListTasksPaginationWithDifferentTimestamps() { .contextId("context-diff-timestamps") .status(new TaskStatus(TaskState.WORKING, null, now.minusMinutes(1))) .build(); - taskStore.save(task5); + taskStore.save(task5, false); // Expected 
order (timestamp DESC, id ASC): // 1. task-diff-d (now) @@ -616,7 +616,7 @@ public void testListTasksHistoryLimiting() { .history(longHistory) .build(); - taskStore.save(task); + taskStore.save(task, false); // List with historyLength=3 (should keep only last 3 messages) - filter by unique context ListTasksParams params = ListTasksParams.builder() @@ -654,7 +654,7 @@ public void testListTasksArtifactInclusion() { .artifacts(artifacts) .build(); - taskStore.save(task); + taskStore.save(task, false); // List without artifacts (default) - filter by unique context ListTasksParams paramsWithoutArtifacts = ListTasksParams.builder() @@ -691,7 +691,7 @@ public void testListTasksDefaultPageSize() { .contextId("context-default-pagesize") .status(new TaskStatus(TaskState.SUBMITTED)) .build(); - taskStore.save(task); + taskStore.save(task, false); } // List without specifying pageSize (should use default of 50) @@ -715,7 +715,7 @@ public void testListTasksInvalidPageTokenFormat() { .contextId("context-invalid-token") .status(new TaskStatus(TaskState.WORKING)) .build(); - taskStore.save(task); + taskStore.save(task, false); // Test 1: Legacy ID-only pageToken should throw InvalidParamsError ListTasksParams params1 = ListTasksParams.builder() @@ -777,9 +777,9 @@ public void testListTasksOrderingById() { .build(); // Save in reverse order - taskStore.save(task3); - taskStore.save(task1); - taskStore.save(task2); + taskStore.save(task3, false); + taskStore.save(task1, false); + taskStore.save(task2, false); // List should return sorted by timestamp DESC (all same), then by ID ASC ListTasksParams params = ListTasksParams.builder() diff --git a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java index 18e18a2f1..cb5bdb25b 100644 --- a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java +++ b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java @@ -13,7 +13,6 @@ import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; @@ -21,6 +20,7 @@ import com.google.gson.JsonSyntaxException; import io.a2a.common.A2AHeaders; +import io.a2a.server.util.sse.SseFormatter; import io.a2a.grpc.utils.JSONRPCUtils; import io.a2a.jsonrpc.common.json.IdJsonMappingException; import io.a2a.jsonrpc.common.json.InvalidParamsJsonMappingException; @@ -65,7 +65,6 @@ import io.a2a.transport.jsonrpc.handler.JSONRPCHandler; import io.quarkus.security.Authenticated; import io.quarkus.vertx.web.Body; -import io.quarkus.vertx.web.ReactiveRoutes; import io.quarkus.vertx.web.Route; import io.smallrye.mutiny.Multi; import io.vertx.core.AsyncResult; @@ -74,6 +73,8 @@ import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpServerResponse; import io.vertx.ext.web.RoutingContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @Singleton public class A2AServerRoutes { @@ -135,8 +136,12 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) { } else if (streaming) { final Multi> finalStreamingResponse = streamingResponse; executor.execute(() -> { - MultiSseSupport.subscribeObject( - finalStreamingResponse.map(i -> (Object) i), rc); + // Convert Multi to Multi with SSE formatting + AtomicLong eventIdCounter = new AtomicLong(0); + Multi sseEvents = finalStreamingResponse + 
.map(response -> SseFormatter.formatResponseAsSSE(response, eventIdCounter.getAndIncrement())); + // Write SSE-formatted strings to HTTP response + MultiSseSupport.writeSseStrings(sseEvents, rc, context); }); } else { @@ -295,34 +300,30 @@ private static com.google.protobuf.MessageOrBuilder convertToProto(A2AResponse + * This class only handles HTTP-specific concerns (writing to response, backpressure, disconnect). + * SSE formatting and JSON serialization are handled by {@link SseFormatter}. + */ private static class MultiSseSupport { + private static final Logger logger = LoggerFactory.getLogger(MultiSseSupport.class); private MultiSseSupport() { // Avoid direct instantiation. } - private static void initialize(HttpServerResponse response) { - if (response.bytesWritten() == 0) { - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); - } - response.setChunked(true); - } - } - - private static void onWriteDone(Flow.Subscription subscription, AsyncResult ar, RoutingContext rc) { - if (ar.failed()) { - rc.fail(ar.cause()); - } else { - subscription.request(1); - } - } - - public static void write(Multi multi, RoutingContext rc) { + /** + * Write SSE-formatted strings to HTTP response. + * + * @param sseStrings Multi stream of SSE-formatted strings (from SseFormatter) + * @param rc Vert.x routing context + * @param context A2A server call context (for EventConsumer cancellation) + */ + public static void writeSseStrings(Multi sseStrings, RoutingContext rc, ServerCallContext context) { HttpServerResponse response = rc.response(); - multi.subscribe().withSubscriber(new Flow.Subscriber() { + + sseStrings.subscribe().withSubscriber(new Flow.Subscriber() { Flow.Subscription upstream; @Override @@ -330,6 +331,13 @@ public void onSubscribe(Flow.Subscription subscription) { this.upstream = subscription; this.upstream.request(1); + // Detect client disconnect and call EventConsumer.cancel() directly + response.closeHandler(v -> { + logger.info("SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + }); + // Notify tests that we are subscribed Runnable runnable = streamingMultiSseSupportSubscribedRunnable; if (runnable != null) { @@ -338,54 +346,50 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(Buffer item) { - initialize(response); - response.write(item, new Handler>() { + public void onNext(String sseEvent) { + // Set SSE headers on first event + if (response.bytesWritten() == 0) { + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } + response.setChunked(true); + } + + // Write SSE-formatted string to response + response.write(Buffer.buffer(sseEvent), new Handler>() { @Override public void handle(AsyncResult ar) { - onWriteDone(upstream, ar, rc); + if (ar.failed()) { + // Client disconnected or write failed - cancel upstream to stop EventConsumer + upstream.cancel(); + rc.fail(ar.cause()); + } else { + upstream.request(1); + } } }); } @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + upstream.cancel(); rc.fail(throwable); } @Override public void onComplete() { - endOfStream(response); - } - }); - } - - public static void subscribeObject(Multi multi, RoutingContext rc) { - AtomicLong count = new AtomicLong(); - 
write(multi.map(new Function() { - @Override - public Buffer apply(Object o) { - if (o instanceof ReactiveRoutes.ServerSentEvent) { - ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o; - long id = ev.id() != -1 ? ev.id() : count.getAndIncrement(); - String e = ev.event() == null ? "" : "event: " + ev.event() + "\n"; - String data = serializeResponse((A2AResponse) ev.data()); - return Buffer.buffer(e + "data: " + data + "\nid: " + id + "\n\n"); + if (response.bytesWritten() == 0) { + // No events written - still set SSE content type + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } } - String data = serializeResponse((A2AResponse) o); - return Buffer.buffer("data: " + data + "\nid: " + count.getAndIncrement() + "\n\n"); - } - }), rc); - } - - private static void endOfStream(HttpServerResponse response) { - if (response.bytesWritten() == 0) { // No item - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + response.end(); } - } - response.end(); + }); } } } diff --git a/reference/jsonrpc/src/test/resources/application.properties b/reference/jsonrpc/src/test/resources/application.properties index 7b9cea9cc..e612925d4 100644 --- a/reference/jsonrpc/src/test/resources/application.properties +++ b/reference/jsonrpc/src/test/resources/application.properties @@ -1 +1,6 @@ quarkus.arc.selected-alternatives=io.a2a.server.apps.common.TestHttpClient + +# Debug logging for event processing and request handling +quarkus.log.category."io.a2a.server.events".level=DEBUG +quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG +quarkus.log.category."io.a2a.server.tasks".level=DEBUG diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java index 46d0d38e6..7a50f0afb 100644 --- a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java +++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java @@ -15,7 +15,8 @@ import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; + +import io.a2a.server.util.sse.SseFormatter; import jakarta.annotation.security.PermitAll; import jakarta.enterprise.inject.Instance; @@ -38,7 +39,6 @@ import io.a2a.util.Utils; import io.quarkus.security.Authenticated; import io.quarkus.vertx.web.Body; -import io.quarkus.vertx.web.ReactiveRoutes; import io.quarkus.vertx.web.Route; import io.smallrye.mutiny.Multi; import io.vertx.core.AsyncResult; @@ -110,10 +110,14 @@ public void sendMessageStreaming(@Body String body, RoutingContext rc) { if (error != null) { sendResponse(rc, error); } else if (streamingResponse != null) { - Multi events = Multi.createFrom().publisher(streamingResponse.getPublisher()); + final HTTPRestStreamingResponse finalStreamingResponse = streamingResponse; executor.execute(() -> { - MultiSseSupport.subscribeObject( - events.map(i -> (Object) i), rc); + // Convert Flow.Publisher (JSON) to Multi (SSE-formatted) + AtomicLong eventIdCounter = new AtomicLong(0); + Multi sseEvents = Multi.createFrom().publisher(finalStreamingResponse.getPublisher()) + .map(json -> SseFormatter.formatJsonAsSSE(json, eventIdCounter.getAndIncrement())); + // Write SSE-formatted strings to HTTP response + MultiSseSupport.writeSseStrings(sseEvents, rc, 
context); }); } } @@ -243,10 +247,14 @@ public void subscribeToTask(RoutingContext rc) { if (error != null) { sendResponse(rc, error); } else if (streamingResponse != null) { - Multi events = Multi.createFrom().publisher(streamingResponse.getPublisher()); + final HTTPRestStreamingResponse finalStreamingResponse = streamingResponse; executor.execute(() -> { - MultiSseSupport.subscribeObject( - events.map(i -> (Object) i), rc); + // Convert Flow.Publisher (JSON) to Multi (SSE-formatted) + AtomicLong eventIdCounter = new AtomicLong(0); + Multi sseEvents = Multi.createFrom().publisher(finalStreamingResponse.getPublisher()) + .map(json -> SseFormatter.formatJsonAsSSE(json, eventIdCounter.getAndIncrement())); + // Write SSE-formatted strings to HTTP response + MultiSseSupport.writeSseStrings(sseEvents, rc, context); }); } } @@ -450,34 +458,30 @@ public String getUsername() { } } - // Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API + /** + * Simplified SSE support for Vert.x/Quarkus. + *

+ * This class only handles HTTP-specific concerns (writing to response, backpressure, disconnect). + * SSE formatting and JSON serialization are handled by {@link SseFormatter}. + */ private static class MultiSseSupport { + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MultiSseSupport.class); private MultiSseSupport() { // Avoid direct instantiation. } - private static void initialize(HttpServerResponse response) { - if (response.bytesWritten() == 0) { - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); - } - response.setChunked(true); - } - } - - private static void onWriteDone(Flow.@Nullable Subscription subscription, AsyncResult ar, RoutingContext rc) { - if (ar.failed()) { - rc.fail(ar.cause()); - } else if (subscription != null) { - subscription.request(1); - } - } - - private static void write(Multi multi, RoutingContext rc) { + /** + * Write SSE-formatted strings to HTTP response. + * + * @param sseStrings Multi stream of SSE-formatted strings (from SseFormatter) + * @param rc Vert.x routing context + * @param context A2A server call context (for EventConsumer cancellation) + */ + public static void writeSseStrings(Multi sseStrings, RoutingContext rc, ServerCallContext context) { HttpServerResponse response = rc.response(); - multi.subscribe().withSubscriber(new Flow.Subscriber() { + + sseStrings.subscribe().withSubscriber(new Flow.Subscriber() { Flow.@Nullable Subscription upstream; @Override @@ -485,6 +489,13 @@ public void onSubscribe(Flow.Subscription subscription) { this.upstream = subscription; this.upstream.request(1); + // Detect client disconnect and call EventConsumer.cancel() directly + response.closeHandler(v -> { + logger.debug("REST SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + }); + // Notify tests that we are subscribed Runnable runnable = streamingMultiSseSupportSubscribedRunnable; if (runnable != null) { @@ -493,53 +504,64 @@ public void onSubscribe(Flow.Subscription subscription) { } @Override - public void onNext(Buffer item) { - initialize(response); - response.write(item, new Handler>() { + public void onNext(String sseEvent) { + // Set SSE headers on first event + if (response.bytesWritten() == 0) { + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } + // Additional SSE headers to prevent buffering + headers.set("Cache-Control", "no-cache"); + headers.set("X-Accel-Buffering", "no"); // Disable nginx buffering + response.setChunked(true); + + // CRITICAL: Disable write queue max size to prevent buffering + // Vert.x buffers writes by default - we need immediate flushing for SSE + response.setWriteQueueMaxSize(1); // Force immediate flush + + // Send initial SSE comment to kickstart the stream + // This forces Vert.x to send headers and start the stream immediately + response.write(": SSE stream started\n\n"); + } + + // Write SSE-formatted string to response + response.write(Buffer.buffer(sseEvent), new Handler>() { @Override public void handle(AsyncResult ar) { - onWriteDone(upstream, ar, rc); + if (ar.failed()) { + // Client disconnected or write failed - cancel upstream to stop EventConsumer + // NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).cancel(); + rc.fail(ar.cause()); + } else { + 
// NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).request(1); + } } }); } @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + // NullAway: upstream is guaranteed non-null after onSubscribe + java.util.Objects.requireNonNull(upstream).cancel(); rc.fail(throwable); } @Override public void onComplete() { - endOfStream(response); - } - }); - } - - private static void subscribeObject(Multi multi, RoutingContext rc) { - AtomicLong count = new AtomicLong(); - write(multi.map(new Function() { - @Override - public Buffer apply(Object o) { - if (o instanceof ReactiveRoutes.ServerSentEvent) { - ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o; - long id = ev.id() != -1 ? ev.id() : count.getAndIncrement(); - String e = ev.event() == null ? "" : "event: " + ev.event() + "\n"; - return Buffer.buffer(e + "data: " + ev.data() + "\nid: " + id + "\n\n"); - } else { - return Buffer.buffer("data: " + o + "\nid: " + count.getAndIncrement() + "\n\n"); + if (response.bytesWritten() == 0) { + // No events written - still set SSE content type + MultiMap headers = response.headers(); + if (headers.get(CONTENT_TYPE) == null) { + headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); + } } + response.end(); } - }), rc); - } - - private static void endOfStream(HttpServerResponse response) { - if (response.bytesWritten() == 0) { // No item - MultiMap headers = response.headers(); - if (headers.get(CONTENT_TYPE) == null) { - headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS); - } - } - response.end(); + }); } } diff --git a/server-common/src/main/java/io/a2a/server/ServerCallContext.java b/server-common/src/main/java/io/a2a/server/ServerCallContext.java index ba5c20b95..c12c60c21 100644 --- a/server-common/src/main/java/io/a2a/server/ServerCallContext.java +++ b/server-common/src/main/java/io/a2a/server/ServerCallContext.java @@ -16,6 +16,7 @@ public class ServerCallContext { private final Set requestedExtensions; private final Set activatedExtensions; private final @Nullable String requestedProtocolVersion; + private volatile @Nullable Runnable eventConsumerCancelCallback; public ServerCallContext(User user, Map state, Set requestedExtensions) { this(user, state, requestedExtensions, null); @@ -64,4 +65,64 @@ public boolean isExtensionRequested(String extensionUri) { public @Nullable String getRequestedProtocolVersion() { return requestedProtocolVersion; } + + /** + * Sets the callback to be invoked when the client disconnects or the call is cancelled. + *

+     * <p>
+     * This callback is typically used to stop the EventConsumer polling loop when a client
+     * disconnects from a streaming endpoint. The callback is invoked by transport layers
+     * (JSON-RPC over HTTP/SSE, REST over HTTP/SSE, gRPC streaming) when they detect that
+     * the client has closed the connection.
+     * <p>
+     * Thread Safety: The callback may be invoked from any thread, depending
+     * on the transport implementation. Implementations should be thread-safe.
+     * <p>
+     * Example Usage:
+     * <pre>{@code
+     * EventConsumer consumer = new EventConsumer(queue);
+     * context.setEventConsumerCancelCallback(consumer::cancel);
+     * }</pre>
+     *
+     * @param callback the callback to invoke on client disconnect, or null to clear any existing callback
+     * @see #invokeEventConsumerCancelCallback()
+     */
+    public void setEventConsumerCancelCallback(@Nullable Runnable callback) {
+        this.eventConsumerCancelCallback = callback;
+    }
+
+    /**
+     * Invokes the EventConsumer cancel callback if one has been set.
+     * <p>
+     * This method is called by transport layers when a client disconnects or cancels a
+     * streaming request. It triggers the callback registered via
+     * {@link #setEventConsumerCancelCallback(Runnable)}, which typically stops the
+     * EventConsumer polling loop.
+     * <p>
+     * Transport-Specific Behavior:
+     * <ul>
+     *   <li>JSON-RPC/REST over HTTP/SSE: Called from Vert.x
+     *       {@code HttpServerResponse.closeHandler()} when the SSE connection is closed</li>
+     *   <li>gRPC streaming: Called from gRPC
+     *       {@code Context.CancellationListener.cancelled()} when the call is cancelled</li>
+     * </ul>
+     * <p>
+     * Thread Safety: This method is thread-safe. The callback is stored
+     * in a volatile field and null-checked before invocation to prevent race conditions.
+     * <p>
+     * If no callback has been set, this method does nothing (no-op).
+ * + * @see #setEventConsumerCancelCallback(Runnable) + * @see io.a2a.server.events.EventConsumer#cancel() + */ + public void invokeEventConsumerCancelCallback() { + Runnable callback = this.eventConsumerCancelCallback; + if (callback != null) { + callback.run(); + } + } } diff --git a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java index d4fe5b395..7c7b28452 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java +++ b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java @@ -19,10 +19,17 @@ public class EventConsumer { private static final Logger LOGGER = LoggerFactory.getLogger(EventConsumer.class); private final EventQueue queue; private volatile @Nullable Throwable error; + private volatile boolean cancelled = false; + private volatile boolean agentCompleted = false; + private volatile int pollTimeoutsAfterAgentCompleted = 0; private static final String ERROR_MSG = "Agent did not return any response"; private static final int NO_WAIT = -1; private static final int QUEUE_WAIT_MILLISECONDS = 500; + // In replicated scenarios, events can arrive hundreds of milliseconds after local agent completes + // Grace period allows Kafka replication to deliver late-arriving events + // 3 timeouts * 500ms = 1500ms grace period for replication delays + private static final int MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED = 3; public EventConsumer(EventQueue queue) { this.queue = queue; @@ -45,6 +52,14 @@ public Flow.Publisher consumeAll() { boolean completed = false; try { while (true) { + // Check if cancelled by client disconnect + if (cancelled) { + LOGGER.debug("EventConsumer detected cancellation, exiting polling loop for queue {}", System.identityHashCode(queue)); + completed = true; + tube.complete(); + return; + } + if (error != null) { completed = true; tube.fail(error); @@ -60,13 +75,49 @@ public Flow.Publisher consumeAll() { EventQueueItem item; Event event; try { + LOGGER.debug("EventConsumer polling queue {} (error={}, agentCompleted={})", + System.identityHashCode(queue), error, agentCompleted); item = queue.dequeueEventItem(QUEUE_WAIT_MILLISECONDS); if (item == null) { + int queueSize = queue.size(); + LOGGER.debug("EventConsumer poll timeout (null item), agentCompleted={}, queue.size()={}, timeoutCount={}", + agentCompleted, queueSize, pollTimeoutsAfterAgentCompleted); + // If agent completed, a poll timeout means no more events are coming + // MainEventBusProcessor has 500ms to distribute events from MainEventBus + // If we timeout with agentCompleted=true, all events have been distributed + // + // IMPORTANT: In replicated scenarios, remote events may arrive AFTER local agent completes! 
+ // Use grace period to allow for Kafka replication delays (can be 400-500ms) + if (agentCompleted && queueSize == 0) { + pollTimeoutsAfterAgentCompleted++; + if (pollTimeoutsAfterAgentCompleted >= MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED) { + LOGGER.debug("Agent completed with {} consecutive poll timeouts and empty queue, closing for graceful completion (queue={})", + pollTimeoutsAfterAgentCompleted, System.identityHashCode(queue)); + queue.close(); + completed = true; + tube.complete(); + return; + } else { + LOGGER.debug("Agent completed but grace period active ({}/{} timeouts), continuing to poll (queue={})", + pollTimeoutsAfterAgentCompleted, MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED, System.identityHashCode(queue)); + } + } else if (agentCompleted && queueSize > 0) { + LOGGER.debug("Agent completed but queue has {} pending events, resetting timeout counter and continuing to poll (queue={})", + queueSize, System.identityHashCode(queue)); + pollTimeoutsAfterAgentCompleted = 0; // Reset counter when events arrive + } continue; } + // Event received - reset timeout counter + pollTimeoutsAfterAgentCompleted = 0; event = item.getEvent(); + LOGGER.debug("EventConsumer received event: {} (queue={})", + event.getClass().getSimpleName(), System.identityHashCode(queue)); + // Defensive logging for error handling if (event instanceof Throwable thr) { + LOGGER.debug("EventConsumer detected Throwable event: {} - triggering tube.fail()", + thr.getClass().getSimpleName()); tube.fail(thr); return; } @@ -90,14 +141,30 @@ public Flow.Publisher consumeAll() { // Only send event if it's not a QueueClosedEvent // QueueClosedEvent is an internal coordination event used for replication // and should not be exposed to API consumers + boolean isFinalSent = false; if (!(event instanceof QueueClosedEvent)) { tube.send(item); + isFinalSent = isFinalEvent; } if (isFinalEvent) { LOGGER.debug("Final or interrupted event detected, closing queue and breaking loop for queue {}", System.identityHashCode(queue)); queue.close(); LOGGER.debug("Queue closed, breaking loop for queue {}", System.identityHashCode(queue)); + + // CRITICAL: Allow tube buffer to flush before calling tube.complete() + // tube.send() buffers events asynchronously. If we call tube.complete() immediately, + // the stream-end signal can reach the client BEFORE the buffered final event, + // causing the client to close the connection and never receive the final event. + // This is especially important in replicated scenarios where events arrive via Kafka + // and timing is less deterministic. A small delay ensures the buffer flushes. 
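This hunk also introduces client-driven cancellation of the polling loop (the cancelled flag above), set through the new ServerCallContext callback added earlier in this patch. A minimal sketch of that wiring, illustrative only and not part of the diff (the wrapper class and method names are invented; the transport is assumed to already hold the ServerCallContext and a ChildQueue for the task):

import io.a2a.server.ServerCallContext;
import io.a2a.server.events.EventConsumer;
import io.a2a.server.events.EventQueue;

// Sketch of the client-disconnect path; only the imported types come from the patch.
final class CancelWiringSketch {
    void wire(ServerCallContext context, EventQueue childQueue) {
        EventConsumer consumer = new EventConsumer(childQueue);
        // Register cancel() so a client disconnect stops the consumeAll() polling loop.
        context.setEventConsumerCancelCallback(consumer::cancel);
        // Later, the transport's close handler (e.g. Vert.x HttpServerResponse.closeHandler)
        // calls context.invokeEventConsumerCancelCallback(), which runs consumer.cancel().
    }
}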
+ if (isFinalSent) { + try { + Thread.sleep(50); // 50ms to allow SSE buffer flush + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } break; } } catch (EventQueueClosedException e) { @@ -138,12 +205,25 @@ private boolean isStreamTerminatingTask(Task task) { public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() { return agentRunnable -> { + LOGGER.debug("EventConsumer: Agent done callback invoked (hasError={}, queue={})", + agentRunnable.getError() != null, System.identityHashCode(queue)); if (agentRunnable.getError() != null) { error = agentRunnable.getError(); + LOGGER.debug("EventConsumer: Set error field from agent callback"); + } else { + agentCompleted = true; + LOGGER.debug("EventConsumer: Agent completed successfully, set agentCompleted=true, will close queue after draining"); } }; } + public void cancel() { + // Set cancellation flag to stop polling loop + // Called when client disconnects without completing stream + LOGGER.debug("EventConsumer cancelled (client disconnect), stopping polling for queue {}", System.identityHashCode(queue)); + cancelled = true; + } + public void close() { // Close the queue to stop the polling loop in consumeAll() // This will cause EventQueueClosedException and exit the while(true) loop diff --git a/server-common/src/main/java/io/a2a/server/events/EventQueue.java b/server-common/src/main/java/io/a2a/server/events/EventQueue.java index a08f63084..99f8bc2dc 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventQueue.java +++ b/server-common/src/main/java/io/a2a/server/events/EventQueue.java @@ -1,6 +1,7 @@ package io.a2a.server.events; import java.util.List; +import java.util.Objects; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; @@ -11,6 +12,8 @@ import io.a2a.server.tasks.TaskStateProvider; import io.a2a.spec.Event; +import io.a2a.spec.Task; +import io.a2a.spec.TaskStatusUpdateEvent; import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,7 +26,7 @@ * and hierarchical queue structures via MainQueue and ChildQueue implementations. *

*

- * Use {@link #builder()} to create configured instances or extend MainQueue/ChildQueue directly.
+ * Use {@link #builder(MainEventBus)} to create configured instances or extend MainQueue/ChildQueue directly.
 *

*/ public abstract class EventQueue implements AutoCloseable { @@ -36,14 +39,6 @@ public abstract class EventQueue implements AutoCloseable { public static final int DEFAULT_QUEUE_SIZE = 1000; private final int queueSize; - /** - * Internal blocking queue for storing event queue items. - */ - protected final BlockingQueue queue = new LinkedBlockingDeque<>(); - /** - * Semaphore for backpressure control, limiting the number of pending events. - */ - protected final Semaphore semaphore; private volatile boolean closed = false; /** @@ -64,7 +59,6 @@ protected EventQueue(int queueSize) { throw new IllegalArgumentException("Queue size must be greater than 0"); } this.queueSize = queueSize; - this.semaphore = new Semaphore(queueSize, true); LOGGER.trace("Creating {} with queue size: {}", this, queueSize); } @@ -78,8 +72,8 @@ protected EventQueue(EventQueue parent) { LOGGER.trace("Creating {}, parent: {}", this, parent); } - static EventQueueBuilder builder() { - return new EventQueueBuilder(); + static EventQueueBuilder builder(MainEventBus mainEventBus) { + return new EventQueueBuilder().mainEventBus(mainEventBus); } /** @@ -95,6 +89,7 @@ public static class EventQueueBuilder { private @Nullable String taskId; private List onCloseCallbacks = new java.util.ArrayList<>(); private @Nullable TaskStateProvider taskStateProvider; + private @Nullable MainEventBus mainEventBus; /** * Sets the maximum queue size. @@ -153,17 +148,31 @@ public EventQueueBuilder taskStateProvider(TaskStateProvider taskStateProvider) return this; } + /** + * Sets the main event bus + * + * @param mainEventBus the main event bus + * @return this builder + */ + public EventQueueBuilder mainEventBus(MainEventBus mainEventBus) { + this.mainEventBus = mainEventBus; + return this; + } + /** * Builds and returns the configured EventQueue. * * @return a new MainQueue instance */ public EventQueue build() { - if (hook != null || !onCloseCallbacks.isEmpty() || taskStateProvider != null) { - return new MainQueue(queueSize, hook, taskId, onCloseCallbacks, taskStateProvider); - } else { - return new MainQueue(queueSize); + // MainEventBus is REQUIRED - enforce single architectural path + if (mainEventBus == null) { + throw new IllegalStateException("MainEventBus is required for EventQueue creation"); } + if (taskId == null) { + throw new IllegalStateException("taskId is required for EventQueue creation"); + } + return new MainQueue(queueSize, hook, taskId, onCloseCallbacks, taskStateProvider, mainEventBus); } } @@ -209,21 +218,39 @@ public void enqueueEvent(Event event) { * @param item the event queue item to enqueue * @throws RuntimeException if interrupted while waiting to acquire the semaphore */ - public void enqueueItem(EventQueueItem item) { - Event event = item.getEvent(); - if (closed) { - LOGGER.warn("Queue is closed. Event will not be enqueued. {} {}", this, event); - return; - } - // Call toString() since for errors we don't really want the full stacktrace - try { - semaphore.acquire(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e); - } - queue.add(item); - LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); + public abstract void enqueueItem(EventQueueItem item); + + /** + * Enqueues an event directly to this specific queue only, bypassing the MainEventBus. + *

+ * This method is used for enqueuing already-persisted events (e.g., current task state + * on resubscribe) that should only be sent to this specific subscriber, not distributed + * to all children or sent through MainEventBusProcessor. + *

+ *

+ * Default implementation throws UnsupportedOperationException. Only ChildQueue supports this. + *

+ * + * @param item the event queue item to enqueue directly + * @throws UnsupportedOperationException if called on MainQueue or other queue types + */ + public void enqueueLocalOnly(EventQueueItem item) { + throw new UnsupportedOperationException( + "enqueueLocalOnly is only supported on ChildQueue for resubscribe scenarios"); + } + + /** + * Enqueues an event directly to this specific queue only, bypassing the MainEventBus. + *

+ * Convenience method that wraps the event in a LocalEventQueueItem before calling + * {@link #enqueueLocalOnly(EventQueueItem)}. + *

+ * + * @param event the event to enqueue directly + * @throws UnsupportedOperationException if called on MainQueue or other queue types + */ + public void enqueueEventLocalOnly(Event event) { + enqueueLocalOnly(new LocalEventQueueItem(event)); } /** @@ -244,48 +271,17 @@ public void enqueueItem(EventQueueItem item) { * This method returns the full EventQueueItem wrapper, allowing callers to check * metadata like whether the event is replicated via {@link EventQueueItem#isReplicated()}. *

+ *

+ * Note: MainQueue does not support dequeue operations - only ChildQueues can be consumed. + *

* * @param waitMilliSeconds the maximum time to wait in milliseconds * @return the EventQueueItem, or null if timeout occurs * @throws EventQueueClosedException if the queue is closed and empty + * @throws UnsupportedOperationException if called on MainQueue */ - public @Nullable EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { - if (closed && queue.isEmpty()) { - LOGGER.debug("Queue is closed, and empty. Sending termination message. {}", this); - throw new EventQueueClosedException(); - } - try { - if (waitMilliSeconds <= 0) { - EventQueueItem item = queue.poll(); - if (item != null) { - Event event = item.getEvent(); - // Call toString() since for errors we don't really want the full stacktrace - LOGGER.debug("Dequeued event item (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event); - semaphore.release(); - } - return item; - } - try { - LOGGER.trace("Polling queue {} (wait={}ms)", System.identityHashCode(this), waitMilliSeconds); - EventQueueItem item = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); - if (item != null) { - Event event = item.getEvent(); - // Call toString() since for errors we don't really want the full stacktrace - LOGGER.debug("Dequeued event item (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event); - semaphore.release(); - } else { - LOGGER.trace("Dequeue timeout (null) from queue {}", System.identityHashCode(this)); - } - return item; - } catch (InterruptedException e) { - LOGGER.debug("Interrupted dequeue (waiting) {}", this); - Thread.currentThread().interrupt(); - return null; - } - } finally { - signalQueuePollerStarted(); - } - } + @Nullable + public abstract EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException; /** * Placeholder method for task completion notification. @@ -295,6 +291,18 @@ public void taskDone() { // TODO Not sure if needed yet. BlockingQueue.poll()/.take() remove the events. } + /** + * Returns the current size of the queue. + *

+ * For MainQueue: returns the number of events in-flight (in MainEventBus queue + currently being processed). + * This reflects actual capacity usage tracked by the semaphore. + * For ChildQueue: returns the size of the local consumption queue. + *

+ * + * @return the number of events currently in the queue + */ + public abstract int size(); + /** * Closes this event queue gracefully, allowing pending events to be consumed. */ @@ -348,72 +356,71 @@ protected void doClose(boolean immediate) { LOGGER.debug("Closing {} (immediate={})", this, immediate); closed = true; } - - if (immediate) { - // Immediate close: clear pending events - queue.clear(); - LOGGER.debug("Cleared queue for immediate close: {}", this); - } - // For graceful close, let the queue drain naturally through normal consumption + // Subclasses handle immediate close logic (e.g., ChildQueue clears its local queue) } static class MainQueue extends EventQueue { private final List children = new CopyOnWriteArrayList<>(); + protected final Semaphore semaphore; private final CountDownLatch pollingStartedLatch = new CountDownLatch(1); private final AtomicBoolean pollingStarted = new AtomicBoolean(false); private final @Nullable EventEnqueueHook enqueueHook; - private final @Nullable String taskId; + private final String taskId; private final List onCloseCallbacks; private final @Nullable TaskStateProvider taskStateProvider; - - MainQueue() { - super(); - this.enqueueHook = null; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize) { - super(queueSize); - this.enqueueHook = null; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(EventEnqueueHook hook) { - super(); - this.enqueueHook = hook; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize, EventEnqueueHook hook) { - super(queueSize); - this.enqueueHook = hook; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize, @Nullable EventEnqueueHook hook, @Nullable String taskId, List onCloseCallbacks, @Nullable TaskStateProvider taskStateProvider) { + private final MainEventBus mainEventBus; + + MainQueue(int queueSize, + @Nullable EventEnqueueHook hook, + String taskId, + List onCloseCallbacks, + @Nullable TaskStateProvider taskStateProvider, + @Nullable MainEventBus mainEventBus) { super(queueSize); + this.semaphore = new Semaphore(queueSize, true); this.enqueueHook = hook; this.taskId = taskId; this.onCloseCallbacks = List.copyOf(onCloseCallbacks); // Defensive copy this.taskStateProvider = taskStateProvider; - LOGGER.debug("Created MainQueue for task {} with {} onClose callbacks and TaskStateProvider: {}", + this.mainEventBus = Objects.requireNonNull(mainEventBus, "MainEventBus is required"); + LOGGER.debug("Created MainQueue for task {} with {} onClose callbacks, TaskStateProvider: {}, MainEventBus configured", taskId, onCloseCallbacks.size(), taskStateProvider != null); } + public EventQueue tap() { ChildQueue child = new ChildQueue(this); children.add(child); return child; } + /** + * Returns the current number of child queues. + * Useful for debugging and logging event distribution. + */ + public int getChildCount() { + return children.size(); + } + + /** + * Returns the enqueue hook for replication (package-protected for MainEventBusProcessor). 
+ */ + @Nullable EventEnqueueHook getEnqueueHook() { + return enqueueHook; + } + + @Override + public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { + throw new UnsupportedOperationException("MainQueue cannot be consumed directly - use tap() to create a ChildQueue for consumption"); + } + + @Override + public int size() { + // Return total in-flight events (in MainEventBus + being processed) + // This aligns with semaphore's capacity tracking + return getQueueSize() - semaphore.availablePermits(); + } + @Override public void enqueueItem(EventQueueItem item) { // MainQueue must accept events even when closed to support: @@ -424,6 +431,15 @@ public void enqueueItem(EventQueueItem item) { // We bypass the parent's closed check and enqueue directly Event event = item.getEvent(); + // Check if this is a final event BEFORE submitting to MainEventBus + // If it is, notify all children to expect it (so they wait for MainEventBusProcessor) + if (isFinalEvent(event)) { + LOGGER.debug("Final event detected, notifying {} children to expect it", children.size()); + for (ChildQueue child : children) { + child.expectFinalEvent(); + } + } + // Acquire semaphore for backpressure try { semaphore.acquire(); @@ -432,17 +448,27 @@ public void enqueueItem(EventQueueItem item) { throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e); } - // Add to this MainQueue's internal queue - queue.add(item); LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); - // Distribute to all ChildQueues (they will receive the event even if MainQueue is closed) - children.forEach(eq -> eq.internalEnqueueItem(item)); + // Submit to MainEventBus for centralized persistence + distribution + // MainEventBus is guaranteed non-null by constructor requirement + // Note: Replication now happens in MainEventBusProcessor AFTER persistence - // Trigger replication hook if configured - if (enqueueHook != null) { - enqueueHook.onEnqueue(item); + // Submit event to MainEventBus with our taskId + mainEventBus.submit(taskId, this, item); + } + + /** + * Checks if an event represents a final task state. 
+ */ + private boolean isFinalEvent(Event event) { + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); } + return false; } @Override @@ -465,20 +491,15 @@ public void signalQueuePollerStarted() { void childClosing(ChildQueue child, boolean immediate) { children.remove(child); // Remove the closing child - // Close immediately if requested - if (immediate) { - LOGGER.debug("MainQueue closing immediately (immediate=true)"); - this.doClose(immediate); - return; - } - // If there are still children, keep queue open if (!children.isEmpty()) { LOGGER.debug("MainQueue staying open: {} children remaining", children.size()); return; } - // No children left - check if task is finalized before auto-closing + // No children left - check if task is finalized before closing + // IMPORTANT: This check must happen BEFORE the immediate flag check + // to prevent closing queues for non-final tasks (fire-and-forget, resubscription support) if (taskStateProvider != null && taskId != null) { boolean isFinalized = taskStateProvider.isTaskFinalized(taskId); if (!isFinalized) { @@ -493,6 +514,36 @@ void childClosing(ChildQueue child, boolean immediate) { this.doClose(immediate); } + /** + * Distribute event to all ChildQueues. + * Called by MainEventBusProcessor after TaskStore persistence. + */ + void distributeToChildren(EventQueueItem item) { + int childCount = children.size(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("MainQueue[{}]: Distributing event {} to {} children", + taskId, item.getEvent().getClass().getSimpleName(), childCount); + } + children.forEach(child -> { + LOGGER.debug("MainQueue[{}]: Enqueueing event {} to child queue", + taskId, item.getEvent().getClass().getSimpleName()); + child.internalEnqueueItem(item); + }); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("MainQueue[{}]: Completed distribution of {} to {} children", + taskId, item.getEvent().getClass().getSimpleName(), childCount); + } + } + + /** + * Release the semaphore after event processing is complete. + * Called by MainEventBusProcessor in finally block to ensure release even on exceptions. + * Balances the acquire() in enqueueEvent() - protects MainEventBus throughput. + */ + void releaseSemaphore() { + semaphore.release(); + } + /** * Get the count of active child queues. * Used for testing to verify reference counting mechanism. 
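With MainQueue construction now funnelled through the MainEventBus, the only supported path is the taskId-aware builder. A rough usage sketch, mirroring createBaseEventQueueBuilder from this patch (the wrapper class and variable names are invented, and EventQueue.builder(..) is package-private, so real callers go through the QueueManager):

package io.a2a.server.events;

import io.a2a.server.tasks.TaskStateProvider;

// Sketch only (not part of the patch): builder() now requires a MainEventBus,
// build() fails fast without a taskId, and consumers must tap() a ChildQueue.
final class QueueBuilderSketch {
    EventQueue childQueueFor(String taskId, MainEventBus bus, TaskStateProvider taskStateProvider) {
        EventQueue mainQueue = EventQueue.builder(bus)        // MainEventBus is mandatory
                .taskId(taskId)                               // build() throws IllegalStateException without it
                .taskStateProvider(taskStateProvider)         // keeps queues of non-final tasks open on close
                .build();                                     // always returns a MainQueue
        return mainQueue.tap();                               // dequeueEventItem() only works on ChildQueues
    }
}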
@@ -539,10 +590,17 @@ public void close(boolean immediate) { public void close(boolean immediate, boolean notifyParent) { throw new UnsupportedOperationException("MainQueue does not support notifyParent parameter - use close(boolean) instead"); } + + String getTaskId() { + return taskId; + } } static class ChildQueue extends EventQueue { private final MainQueue parent; + private final BlockingQueue queue = new LinkedBlockingDeque<>(); + private volatile boolean immediateClose = false; + private volatile boolean awaitingFinalEvent = false; public ChildQueue(MainQueue parent) { this.parent = parent; @@ -553,8 +611,94 @@ public void enqueueEvent(Event event) { parent.enqueueEvent(event); } + @Override + public void enqueueItem(EventQueueItem item) { + // ChildQueue delegates writes to parent MainQueue + parent.enqueueItem(item); + } + private void internalEnqueueItem(EventQueueItem item) { - super.enqueueItem(item); + // Internal method called by MainEventBusProcessor to add to local queue + // Note: Semaphore is managed by parent MainQueue (acquire/release), not ChildQueue + Event event = item.getEvent(); + // For graceful close: still accept events so they can be drained by EventConsumer + // For immediate close: reject events to stop distribution quickly + if (isClosed() && immediateClose) { + LOGGER.warn("ChildQueue is immediately closed. Event will not be enqueued. {} {}", this, event); + return; + } + if (!queue.offer(item)) { + LOGGER.warn("ChildQueue {} is full. Closing immediately.", this); + close(true); // immediate close + } else { + LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); + + // If we were awaiting a final event and this is it, clear the flag + if (awaitingFinalEvent && isFinalEvent(event)) { + awaitingFinalEvent = false; + LOGGER.debug("ChildQueue {} received awaited final event", System.identityHashCode(this)); + } + } + } + + /** + * Checks if an event represents a final task state. + */ + private boolean isFinalEvent(Event event) { + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); + } + return false; + } + + @Override + public void enqueueLocalOnly(EventQueueItem item) { + internalEnqueueItem(item); + } + + @Override + @Nullable + public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { + // For immediate close: exit immediately even if queue is not empty (race with MainEventBusProcessor) + // For graceful close: only exit when queue is empty (wait for all events to be consumed) + // BUT: if awaiting final event, keep polling even if closed and empty + if (isClosed() && (queue.isEmpty() || immediateClose) && !awaitingFinalEvent) { + LOGGER.debug("ChildQueue is closed{}, sending termination message. {} (queueSize={})", + immediateClose ? " (immediate)" : " and empty", + this, + queue.size()); + throw new EventQueueClosedException(); + } + try { + if (waitMilliSeconds <= 0) { + EventQueueItem item = queue.poll(); + if (item != null) { + Event event = item.getEvent(); + LOGGER.debug("Dequeued event item (no wait) {} {}", this, event instanceof Throwable ? 
event.toString() : event); + } + return item; + } + try { + LOGGER.trace("Polling ChildQueue {} (wait={}ms)", System.identityHashCode(this), waitMilliSeconds); + EventQueueItem item = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); + if (item != null) { + Event event = item.getEvent(); + LOGGER.debug("Dequeued event item (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event); + } else { + LOGGER.trace("Dequeue timeout (null) from ChildQueue {}", System.identityHashCode(this)); + } + return item; + } catch (InterruptedException e) { + LOGGER.debug("Interrupted dequeue (waiting) {}", this); + Thread.currentThread().interrupt(); + return null; + } + } finally { + signalQueuePollerStarted(); + } } @Override @@ -562,6 +706,12 @@ public EventQueue tap() { throw new IllegalStateException("Can only tap the main queue"); } + @Override + public int size() { + // Return size of local consumption queue + return queue.size(); + } + @Override public void awaitQueuePollerStart() throws InterruptedException { parent.awaitQueuePollerStart(); @@ -572,6 +722,29 @@ public void signalQueuePollerStarted() { parent.signalQueuePollerStarted(); } + @Override + protected void doClose(boolean immediate) { + super.doClose(immediate); // Sets closed flag + if (immediate) { + // Immediate close: clear pending events from local queue + this.immediateClose = true; + int clearedCount = queue.size(); + queue.clear(); + LOGGER.debug("Cleared {} events from ChildQueue for immediate close: {}", clearedCount, this); + } + // For graceful close, let the queue drain naturally through normal consumption + } + + /** + * Notifies this ChildQueue to expect a final event. + * Called by MainQueue when it enqueues a final event, BEFORE submitting to MainEventBus. + * This ensures the ChildQueue keeps polling until the final event arrives (after MainEventBusProcessor). 
+ */ + void expectFinalEvent() { + awaitingFinalEvent = true; + LOGGER.debug("ChildQueue {} now awaiting final event", System.identityHashCode(this)); + } + @Override public void close() { close(false); diff --git a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java index e5a17e0e7..53a089e4c 100644 --- a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java @@ -34,16 +34,20 @@ protected InMemoryQueueManager() { this.taskStateProvider = null; } + MainEventBus mainEventBus; + @Inject - public InMemoryQueueManager(TaskStateProvider taskStateProvider) { + public InMemoryQueueManager(TaskStateProvider taskStateProvider, MainEventBus mainEventBus) { + this.mainEventBus = mainEventBus; this.factory = new DefaultEventQueueFactory(); this.taskStateProvider = taskStateProvider; } - // For testing with custom factory - public InMemoryQueueManager(EventQueueFactory factory, TaskStateProvider taskStateProvider) { + // For testing/extensions with custom factory and MainEventBus + public InMemoryQueueManager(EventQueueFactory factory, TaskStateProvider taskStateProvider, MainEventBus mainEventBus) { this.factory = factory; this.taskStateProvider = taskStateProvider; + this.mainEventBus = mainEventBus; } @Override @@ -101,7 +105,6 @@ public EventQueue createOrTap(String taskId) { EventQueue newQueue = null; if (existing == null) { // Use builder pattern for cleaner queue creation - // Use the new taskId-aware builder method if available newQueue = factory.builder(taskId).build(); // Make sure an existing queue has not been added in the meantime existing = queues.putIfAbsent(taskId, newQueue); @@ -128,6 +131,12 @@ public void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedExcep eventQueue.awaitQueuePollerStart(); } + @Override + public EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) { + // Use the factory to ensure proper configuration (MainEventBus, callbacks, etc.) + return factory.builder(taskId); + } + @Override public int getActiveChildQueueCount(String taskId) { EventQueue queue = queues.get(taskId); @@ -142,6 +151,14 @@ public int getActiveChildQueueCount(String taskId) { return -1; } + @Override + public EventQueue.EventQueueBuilder createBaseEventQueueBuilder(String taskId) { + return EventQueue.builder(mainEventBus) + .taskId(taskId) + .addOnCloseCallback(getCleanupCallback(taskId)) + .taskStateProvider(taskStateProvider); + } + /** * Get the cleanup callback that removes a queue from the map when it closes. 
* This is exposed so that subclasses (like ReplicatedQueueManager) can reuse @@ -181,11 +198,8 @@ public Runnable getCleanupCallback(String taskId) { private class DefaultEventQueueFactory implements EventQueueFactory { @Override public EventQueue.EventQueueBuilder builder(String taskId) { - // Return builder with callback that removes queue from map when closed - return EventQueue.builder() - .taskId(taskId) - .addOnCloseCallback(getCleanupCallback(taskId)) - .taskStateProvider(taskStateProvider); + // Delegate to the base builder creation method + return createBaseEventQueueBuilder(taskId); } } } diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBus.java b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java new file mode 100644 index 000000000..90080b1e2 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java @@ -0,0 +1,42 @@ +package io.a2a.server.events; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class MainEventBus { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBus.class); + private final BlockingQueue queue; + + public MainEventBus() { + this.queue = new LinkedBlockingDeque<>(); + } + + void submit(String taskId, EventQueue.MainQueue mainQueue, EventQueueItem item) { + try { + queue.put(new MainEventBusContext(taskId, mainQueue, item)); + LOGGER.debug("Submitted event for task {} to MainEventBus (queue size: {})", + taskId, queue.size()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted submitting to MainEventBus", e); + } + } + + MainEventBusContext take() throws InterruptedException { + LOGGER.debug("MainEventBus: Waiting to take event (current queue size: {})...", queue.size()); + MainEventBusContext context = queue.take(); + LOGGER.debug("MainEventBus: Took event for task {} (remaining queue size: {})", + context.taskId(), queue.size()); + return context; + } + + public int size() { + return queue.size(); + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java new file mode 100644 index 000000000..292a60f21 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java @@ -0,0 +1,11 @@ +package io.a2a.server.events; + +import java.util.Objects; + +record MainEventBusContext(String taskId, EventQueue.MainQueue eventQueue, EventQueueItem eventQueueItem) { + MainEventBusContext { + Objects.requireNonNull(taskId, "taskId cannot be null"); + Objects.requireNonNull(eventQueue, "eventQueue cannot be null"); + Objects.requireNonNull(eventQueueItem, "eventQueueItem cannot be null"); + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java new file mode 100644 index 000000000..8b3dc6fa3 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java @@ -0,0 +1,386 @@ +package io.a2a.server.events; + +import java.util.concurrent.CompletableFuture; + +import jakarta.annotation.Nullable; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; 
+ +import io.a2a.server.tasks.PushNotificationSender; +import io.a2a.server.tasks.TaskManager; +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Event; +import io.a2a.spec.InternalError; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Background processor for the MainEventBus. + *

+ * This processor runs in a dedicated background thread, consuming events from the MainEventBus
+ * and performing two critical operations in order:
+ * <ol>
+ *   <li>Update TaskStore with event data (persistence FIRST)</li>
+ *   <li>Distribute event to ChildQueues (clients see it AFTER persistence)</li>
+ * </ol>
+ * <p>
+ * This architecture ensures clients never receive events before they're persisted,
+ * eliminating race conditions and enabling reliable event replay.
+ * <p>
+ * Note: This bean is eagerly initialized by {@link MainEventBusProcessorInitializer}
+ * to ensure the background thread starts automatically when the application starts.
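+ * <p>
+ * Conceptually, the processing loop reduces to the following simplified sketch (the helper
+ * names {@code persistToTaskStore} and {@code distributeToChildQueues} are illustrative,
+ * not actual methods of this class):
+ * <pre>{@code
+ * while (running) {
+ *     MainEventBusContext ctx = eventBus.take(); // blocks until an event is submitted
+ *     persistToTaskStore(ctx);                   // step 1: update the TaskStore
+ *     distributeToChildQueues(ctx);              // step 2: make the event visible to clients
+ * }
+ * }</pre>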

+ */ +@ApplicationScoped +public class MainEventBusProcessor implements Runnable { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessor.class); + + /** + * Callback for testing synchronization with async event processing. + * Default is NOOP to avoid null checks in production code. + * Tests can inject their own callback via setCallback(). + */ + private volatile MainEventBusProcessorCallback callback = MainEventBusProcessorCallback.NOOP; + + /** + * Optional executor for push notifications. + * If null, uses default ForkJoinPool (async). + * Tests can inject a synchronous executor to ensure deterministic ordering. + */ + private volatile @Nullable java.util.concurrent.Executor pushNotificationExecutor = null; + + private final MainEventBus eventBus; + + private final TaskStore taskStore; + + private final PushNotificationSender pushSender; + + private final QueueManager queueManager; + + private volatile boolean running = true; + private @Nullable Thread processorThread; + + @Inject + public MainEventBusProcessor(MainEventBus eventBus, TaskStore taskStore, PushNotificationSender pushSender, QueueManager queueManager) { + this.eventBus = eventBus; + this.taskStore = taskStore; + this.pushSender = pushSender; + this.queueManager = queueManager; + } + + /** + * Set a callback for testing synchronization with async event processing. + *

+ * This is primarily intended for tests that need to wait for event processing to complete.
+ * Pass null to reset to the default NOOP callback.
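+ * <p>
+ * See {@link MainEventBusProcessorCallback} for a complete latch-based usage example.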

+ * + * @param callback the callback to invoke during event processing, or null for NOOP + */ + public void setCallback(MainEventBusProcessorCallback callback) { + this.callback = callback != null ? callback : MainEventBusProcessorCallback.NOOP; + } + + /** + * Set a custom executor for push notifications (primarily for testing). + *

+ * By default, push notifications are sent asynchronously using CompletableFuture.runAsync()
+ * with the default ForkJoinPool. For tests that need deterministic ordering of push
+ * notifications, inject a synchronous executor that runs tasks immediately on the calling thread.
+ * <p>
+ * Example synchronous executor for tests:
+ * <pre>{@code
+ * Executor syncExecutor = Runnable::run;
+ * mainEventBusProcessor.setPushNotificationExecutor(syncExecutor);
+ * }</pre>
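+ * <p>
+ * Passing {@code null} (for example in a test teardown) restores the default asynchronous
+ * behavior backed by the common ForkJoinPool.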
+ * + * @param executor the executor to use for push notifications, or null to use default ForkJoinPool + */ + public void setPushNotificationExecutor(java.util.concurrent.Executor executor) { + this.pushNotificationExecutor = executor; + } + + @PostConstruct + void start() { + processorThread = new Thread(this, "MainEventBusProcessor"); + processorThread.setDaemon(true); // Allow JVM to exit even if this thread is running + processorThread.start(); + LOGGER.info("MainEventBusProcessor started"); + } + + /** + * No-op method to force CDI proxy resolution and ensure @PostConstruct has been called. + * Called by MainEventBusProcessorInitializer during application startup. + */ + public void ensureStarted() { + // Method intentionally empty - just forces proxy resolution + } + + @PreDestroy + void stop() { + LOGGER.info("MainEventBusProcessor stopping..."); + running = false; + if (processorThread != null) { + processorThread.interrupt(); + try { + long start = System.currentTimeMillis(); + processorThread.join(5000); // Wait up to 5 seconds + long elapsed = System.currentTimeMillis() - start; + LOGGER.info("MainEventBusProcessor thread stopped in {}ms", elapsed); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.warn("Interrupted while waiting for MainEventBusProcessor thread to stop"); + } + } + LOGGER.info("MainEventBusProcessor stopped"); + } + + @Override + public void run() { + LOGGER.info("MainEventBusProcessor processing loop started"); + while (running) { + try { + LOGGER.debug("MainEventBusProcessor: Waiting for event from MainEventBus..."); + MainEventBusContext context = eventBus.take(); + LOGGER.debug("MainEventBusProcessor: Retrieved event for task {} from MainEventBus", + context.taskId()); + processEvent(context); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.info("MainEventBusProcessor interrupted, shutting down"); + break; + } catch (Exception e) { + LOGGER.error("Error processing event from MainEventBus", e); + // Continue processing despite errors + } + } + LOGGER.info("MainEventBusProcessor processing loop ended"); + } + + private void processEvent(MainEventBusContext context) { + String taskId = context.taskId(); + Event event = context.eventQueueItem().getEvent(); + // MainEventBus.submit() guarantees this is always a MainQueue + EventQueue.MainQueue mainQueue = (EventQueue.MainQueue) context.eventQueue(); + + LOGGER.debug("MainEventBusProcessor: Processing event for task {}: {}", + taskId, event.getClass().getSimpleName()); + + Event eventToDistribute = null; + boolean isReplicated = context.eventQueueItem().isReplicated(); + try { + // Step 1: Update TaskStore FIRST (persistence before clients see it) + // If this throws, we distribute an error to ensure "persist before client visibility" + + try { + boolean isFinal = updateTaskStore(taskId, event, isReplicated); + + eventToDistribute = event; // Success - distribute original event + + // Trigger replication AFTER successful persistence + // SKIP replication if task is final - ReplicatedQueueManager handles this via TaskFinalizedEvent + // to ensure final Task is sent before poison pill (QueueClosedEvent) + if (!isFinal) { + EventEnqueueHook hook = mainQueue.getEnqueueHook(); + if (hook != null) { + LOGGER.debug("Triggering replication hook for task {} after successful persistence", taskId); + hook.onEnqueue(context.eventQueueItem()); + } + } else { + LOGGER.debug("Task {} is final - skipping replication hook (handled by ReplicatedQueueManager)", 
taskId); + } + } catch (InternalError e) { + // Persistence failed - create error event to distribute instead + LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e); + String errorMessage = "Failed to persist event: " + e.getMessage(); + eventToDistribute = e; + } catch (Exception e) { + LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e); + String errorMessage = "Failed to persist event: " + e.getMessage(); + eventToDistribute = new InternalError(errorMessage); + } + + // Step 2: Send push notification AFTER successful persistence (only from active node) + // Skip push notifications for replicated events to avoid duplicate notifications in multi-instance deployments + if (eventToDistribute == event && !isReplicated) { + // Capture task state immediately after persistence, before going async + // This ensures we send the task as it existed when THIS event was processed, + // not whatever state might exist later when the async callback executes + Task taskSnapshot = taskStore.get(taskId); + if (taskSnapshot != null) { + sendPushNotification(taskId, taskSnapshot); + } else { + LOGGER.warn("Task {} not found in TaskStore after successful persistence, skipping push notification", taskId); + } + } + + // Step 3: Then distribute to ChildQueues (clients see either event or error AFTER persistence attempt) + if (eventToDistribute == null) { + LOGGER.error("MainEventBusProcessor: eventToDistribute is NULL for task {} - this should never happen!", taskId); + eventToDistribute = new InternalError("Internal error: event processing failed"); + } + + int childCount = mainQueue.getChildCount(); + LOGGER.debug("MainEventBusProcessor: Distributing {} to {} children for task {}", + eventToDistribute.getClass().getSimpleName(), childCount, taskId); + // Create new EventQueueItem with the event to distribute (original or error) + EventQueueItem itemToDistribute = new LocalEventQueueItem(eventToDistribute); + mainQueue.distributeToChildren(itemToDistribute); + LOGGER.debug("MainEventBusProcessor: Distributed {} to {} children for task {}", + eventToDistribute.getClass().getSimpleName(), childCount, taskId); + + LOGGER.debug("MainEventBusProcessor: Completed processing event for task {}", taskId); + + } finally { + try { + // Step 4: Notify callback after all processing is complete + // Call callback with the distributed event (original or error) + if (eventToDistribute != null) { + callback.onEventProcessed(taskId, eventToDistribute); + + // Step 5: If this is a final event, notify task finalization + // Only for successful persistence (not for errors) + if (eventToDistribute == event && isFinalEvent(event)) { + callback.onTaskFinalized(taskId); + } + } + } finally { + // ALWAYS release semaphore, even if processing fails + // Balances the acquire() in MainQueue.enqueueEvent() + mainQueue.releaseSemaphore(); + } + } + } + + /** + * Updates TaskStore using TaskManager.process(). + *

+ * Creates a temporary TaskManager instance for this event and delegates to its process() method,
+ * which handles all event types (Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent).
+ * This leverages existing TaskManager logic for status updates, artifact appending, message history, etc.
+ * <p>
+ * If persistence fails, the exception is propagated to processEvent() which distributes an
+ * InternalError to clients instead of the original event, ensuring "persist before visibility".
+ * See Gemini's comment: https://github.com/a2aproject/a2a-java/pull/515#discussion_r2604621833

+ * + * @param taskId the task ID + * @param event the event to persist + * @return true if the task reached a final state, false otherwise + * @throws InternalError if persistence fails + */ + private boolean updateTaskStore(String taskId, Event event, boolean isReplicated) throws InternalError { + try { + // Extract contextId from event (all relevant events have it) + String contextId = extractContextId(event); + + // Create temporary TaskManager instance for this event + TaskManager taskManager = new TaskManager(taskId, contextId, taskStore, null); + + // Use TaskManager.process() - handles all event types with existing logic + boolean isFinal = taskManager.process(event, isReplicated); + LOGGER.debug("TaskStore updated via TaskManager.process() for task {}: {} (final: {}, replicated: {})", + taskId, event.getClass().getSimpleName(), isFinal, isReplicated); + return isFinal; + } catch (InternalError e) { + LOGGER.error("Error updating TaskStore via TaskManager for task {}", taskId, e); + // Rethrow to prevent distributing unpersisted event to clients + throw e; + } catch (Exception e) { + LOGGER.error("Unexpected error updating TaskStore for task {}", taskId, e); + // Rethrow to prevent distributing unpersisted event to clients + throw new InternalError("TaskStore persistence failed: " + e.getMessage()); + } + } + + /** + * Sends push notification for the task AFTER persistence. + *

+ * This is called after updateTaskStore() to ensure the notification contains
+ * the latest persisted state, avoiding race conditions.
+ * <p>
+ * CRITICAL: Push notifications are sent asynchronously in the background
+ * to avoid blocking event distribution to ChildQueues. The 83ms overhead from
+ * PushNotificationSender.sendNotification() was causing streaming delays.
+ * <p>
+ * IMPORTANT: The task parameter is a snapshot captured immediately after
+ * persistence. This ensures we send the task state as it existed when THIS event
+ * was processed, not whatever state might exist in TaskStore when the async
+ * callback executes (subsequent events may have already updated the store).
+ * <p>
+ * NOTE: Tests can inject a synchronous executor via setPushNotificationExecutor()
+ * to ensure deterministic ordering of push notifications in the test environment.

+ * + * @param taskId the task ID + * @param task the task snapshot to send (captured immediately after persistence) + */ + private void sendPushNotification(String taskId, Task task) { + Runnable pushTask = () -> { + try { + if (task != null) { + LOGGER.debug("Sending push notification for task {}", taskId); + pushSender.sendNotification(task); + } else { + LOGGER.debug("Skipping push notification - task snapshot is null for task {}", taskId); + } + } catch (Exception e) { + LOGGER.error("Error sending push notification for task {}", taskId, e); + // Don't rethrow - push notifications are best-effort + } + }; + + // Use custom executor if set (for tests), otherwise use default ForkJoinPool (async) + if (pushNotificationExecutor != null) { + pushNotificationExecutor.execute(pushTask); + } else { + CompletableFuture.runAsync(pushTask); + } + } + + /** + * Extracts contextId from an event. + * Returns null if the event type doesn't have a contextId (e.g., Message). + */ + @Nullable + private String extractContextId(Event event) { + if (event instanceof Task task) { + return task.contextId(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.contextId(); + } else if (event instanceof TaskArtifactUpdateEvent artifactUpdate) { + return artifactUpdate.contextId(); + } + // Message and other events don't have contextId + return null; + } + + /** + * Checks if an event represents a final task state. + * + * @param event the event to check + * @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN) + */ + private boolean isFinalEvent(Event event) { + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); + } + return false; + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java new file mode 100644 index 000000000..b0a9adbce --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java @@ -0,0 +1,66 @@ +package io.a2a.server.events; + +import io.a2a.spec.Event; + +/** + * Callback interface for MainEventBusProcessor events. + *

+ * This interface is primarily intended for testing, allowing tests to synchronize
+ * with the asynchronous MainEventBusProcessor. Production code should not rely on this.
+ * <p>
+ * Usage in tests:
+ * <pre>{@code
+ * @Inject
+ * MainEventBusProcessor processor;
+ *
+ * @BeforeEach
+ * void setUp() {
+ *     CountDownLatch latch = new CountDownLatch(3);
+ *     processor.setCallback(new MainEventBusProcessorCallback() {
+ *         @Override
+ *         public void onEventProcessed(String taskId, Event event) {
+ *             latch.countDown();
+ *         }
+ *
+ *         @Override
+ *         public void onTaskFinalized(String taskId) {
+ *             // not needed for this test
+ *         }
+ *     });
+ * }
+ *
+ * @AfterEach
+ * void tearDown() {
+ *     processor.setCallback(null); // Reset to NOOP
+ * }
+ * }</pre>
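+ * <p>
+ * After triggering the request, the test would typically block on the latch, for example
+ * {@code assertTrue(latch.await(10, TimeUnit.SECONDS))} (the timeout value is illustrative).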
+ */ +public interface MainEventBusProcessorCallback { + + /** + * Called after an event has been fully processed (persisted, notification sent, distributed to children). + * + * @param taskId the task ID + * @param event the event that was processed + */ + void onEventProcessed(String taskId, Event event); + + /** + * Called when a task reaches a final state (COMPLETED, FAILED, CANCELED, REJECTED). + * + * @param taskId the task ID that was finalized + */ + void onTaskFinalized(String taskId); + + /** + * No-op implementation that does nothing. + * Used as the default callback to avoid null checks. + */ + MainEventBusProcessorCallback NOOP = new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + // No-op + } + + @Override + public void onTaskFinalized(String taskId) { + // No-op + } + }; +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java new file mode 100644 index 000000000..ba4b300be --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java @@ -0,0 +1,43 @@ +package io.a2a.server.events; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.Initialized; +import jakarta.enterprise.event.Observes; +import jakarta.inject.Inject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Portable CDI initializer for MainEventBusProcessor. + *

+ * This bean observes the ApplicationScoped initialization event and injects
+ * MainEventBusProcessor, which triggers its eager creation and starts the background thread.
+ * <p>
+ * This approach is portable across all Jakarta CDI implementations (Weld, OpenWebBeans, Quarkus, etc.)
+ * and ensures MainEventBusProcessor starts automatically when the application starts.

+ */ +@ApplicationScoped +public class MainEventBusProcessorInitializer { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessorInitializer.class); + + @Inject + MainEventBusProcessor processor; + + /** + * Observes ApplicationScoped initialization to force eager creation of MainEventBusProcessor. + * The injection of MainEventBusProcessor in this bean triggers its creation, and calling + * ensureStarted() forces the CDI proxy to be resolved, which ensures @PostConstruct has been + * called and the background thread is running. + */ + void onStart(@Observes @Initialized(ApplicationScoped.class) Object event) { + if (processor != null) { + // Force proxy resolution to ensure @PostConstruct has been called + processor.ensureStarted(); + LOGGER.info("MainEventBusProcessor initialized and started"); + } else { + LOGGER.error("MainEventBusProcessor is null - initialization failed!"); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/QueueManager.java b/server-common/src/main/java/io/a2a/server/events/QueueManager.java index 01e754fcb..4ad30f0cb 100644 --- a/server-common/src/main/java/io/a2a/server/events/QueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/QueueManager.java @@ -177,7 +177,31 @@ public interface QueueManager { * @return a builder for creating event queues */ default EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) { - return EventQueue.builder(); + throw new UnsupportedOperationException( + "QueueManager implementations must override getEventQueueBuilder() to provide MainEventBus" + ); + } + + /** + * Creates a base EventQueueBuilder with standard configuration for this QueueManager. + * This method provides the foundation for creating event queues with proper configuration + * (MainEventBus, TaskStateProvider, cleanup callbacks, etc.). + *

+     * QueueManager implementations that use custom factories can call this method directly
+     * to get the base builder without going through the factory (which could cause infinite
+     * recursion if the factory delegates back to getEventQueueBuilder()).
+     * <p>
+     * Callers can then add additional configuration (hooks, callbacks) before building the queue.
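+     * <p>
+     * A sketch of a typical override (the exact extra configuration depends on the
+     * implementation):
+     * <pre>{@code
+     * @Override
+     * public EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) {
+     *     // add implementation-specific hooks/callbacks here before returning
+     *     return createBaseEventQueueBuilder(taskId);
+     * }
+     * }</pre>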

+ * + * @param taskId the task ID for the queue + * @return a builder with base configuration specific to this QueueManager implementation + */ + default EventQueue.EventQueueBuilder createBaseEventQueueBuilder(String taskId) { + throw new UnsupportedOperationException( + "QueueManager implementations must override createBaseEventQueueBuilder() to provide MainEventBus" + ); } /** diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 002acbafd..c476c8741 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -11,13 +11,13 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -35,13 +35,15 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; import io.a2a.server.events.EventQueueItem; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.server.events.MainEventBusProcessorCallback; import io.a2a.server.events.QueueManager; -import io.a2a.server.events.TaskQueueExistsException; import io.a2a.server.tasks.PushNotificationConfigStore; import io.a2a.server.tasks.PushNotificationSender; import io.a2a.server.tasks.ResultAggregator; import io.a2a.server.tasks.TaskManager; import io.a2a.server.tasks.TaskStore; +import io.a2a.server.util.async.EventConsumerExecutorProducer.EventConsumerExecutor; import io.a2a.server.util.async.Internal; import io.a2a.spec.A2AError; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; @@ -64,6 +66,7 @@ import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.UnsupportedOperationError; import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; @@ -122,7 +125,6 @@ *
 *   <li>{@link EventConsumer} polls and processes events on Vert.x worker thread</li>
 *   <li>Queue closes automatically on final event (COMPLETED/FAILED/CANCELED)</li>
 *   <li>Cleanup waits for both agent execution AND event consumption to complete</li>
- *   <li>Background tasks tracked via {@link #trackBackgroundTask(CompletableFuture)}</li>
 * </ul>
 *
 * <b>Threading Model</b>
 *
    @@ -179,6 +181,13 @@ public class DefaultRequestHandler implements RequestHandler { private static final Logger LOGGER = LoggerFactory.getLogger(DefaultRequestHandler.class); + /** + * Separate logger for thread statistics diagnostic logging. + * This allows independent control of verbose thread pool monitoring without affecting + * general request handler logging. Enable with: logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG + */ + private static final Logger THREAD_STATS_LOGGER = LoggerFactory.getLogger("io.a2a.server.diagnostics.ThreadStats"); + private static final String A2A_BLOCKING_AGENT_TIMEOUT_SECONDS = "a2a.blocking.agent.timeout.seconds"; private static final String A2A_BLOCKING_CONSUMPTION_TIMEOUT_SECONDS = "a2a.blocking.consumption.timeout.seconds"; @@ -214,13 +223,14 @@ public class DefaultRequestHandler implements RequestHandler { private TaskStore taskStore; private QueueManager queueManager; private PushNotificationConfigStore pushConfigStore; - private PushNotificationSender pushSender; + private MainEventBusProcessor mainEventBusProcessor; private Supplier requestContextBuilder; private final ConcurrentMap> runningAgents = new ConcurrentHashMap<>(); - private final Set> backgroundTasks = ConcurrentHashMap.newKeySet(); + private Executor executor; + private Executor eventConsumerExecutor; /** * No-args constructor for CDI proxy creation. @@ -234,21 +244,25 @@ protected DefaultRequestHandler() { this.taskStore = null; this.queueManager = null; this.pushConfigStore = null; - this.pushSender = null; + this.mainEventBusProcessor = null; this.requestContextBuilder = null; this.executor = null; + this.eventConsumerExecutor = null; } @Inject public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore, QueueManager queueManager, PushNotificationConfigStore pushConfigStore, - PushNotificationSender pushSender, @Internal Executor executor) { + MainEventBusProcessor mainEventBusProcessor, + @Internal Executor executor, + @EventConsumerExecutor Executor eventConsumerExecutor) { this.agentExecutor = agentExecutor; this.taskStore = taskStore; this.queueManager = queueManager; this.pushConfigStore = pushConfigStore; - this.pushSender = pushSender; + this.mainEventBusProcessor = mainEventBusProcessor; this.executor = executor; + this.eventConsumerExecutor = eventConsumerExecutor; // TODO In Python this is also a constructor parameter defaulting to this SimpleRequestContextBuilder // implementation if the parameter is null. Skip that for now, since otherwise I get CDI errors, and // I am unsure about the correct scope. 
@@ -264,16 +278,20 @@ void initConfig() { configProvider.getValue(A2A_BLOCKING_CONSUMPTION_TIMEOUT_SECONDS)); } + /** * For testing */ public static DefaultRequestHandler create(AgentExecutor agentExecutor, TaskStore taskStore, QueueManager queueManager, PushNotificationConfigStore pushConfigStore, - PushNotificationSender pushSender, Executor executor) { + MainEventBusProcessor mainEventBusProcessor, + Executor executor, Executor eventConsumerExecutor) { DefaultRequestHandler handler = - new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, pushSender, executor); + new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, + mainEventBusProcessor, executor, eventConsumerExecutor); handler.agentCompletionTimeoutSeconds = 5; handler.consumptionCompletionTimeoutSeconds = 2; + return handler; } @@ -359,12 +377,9 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws taskStore, null); - ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); - EventQueue queue = queueManager.tap(task.id()); - if (queue == null) { - queue = queueManager.getEventQueueBuilder(task.id()).build(); - } + EventQueue queue = queueManager.createOrTap(task.id()); agentExecutor.cancel( requestContextBuilder.get() .setTaskId(task.id()) @@ -395,28 +410,41 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws @Override public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws A2AError { LOGGER.debug("onMessageSend - task: {}; context {}", params.message().taskId(), params.message().contextId()); + + // Build MessageSendSetup which creates RequestContext with real taskId (auto-generated if needed) MessageSendSetup mss = initMessageSend(params, context); - String taskId = mss.requestContext.getTaskId(); - LOGGER.debug("Request context taskId: {}", taskId); + // Use the taskId from RequestContext for queue management (no temp ID needed!) 
+ // RequestContext.build() guarantees taskId is non-null via checkOrGenerateTaskId() + String queueTaskId = java.util.Objects.requireNonNull( + mss.requestContext.getTaskId(), "TaskId must be non-null after RequestContext.build()"); + LOGGER.debug("Queue taskId: {}", queueTaskId); - if (taskId == null) { - throw new io.a2a.spec.InternalError("Task ID is null in onMessageSend"); - } - EventQueue queue = queueManager.createOrTap(taskId); - ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor); + // Create queue with real taskId (no tempId parameter needed) + EventQueue queue = queueManager.createOrTap(queueTaskId); + final java.util.concurrent.atomic.AtomicReference<@NonNull String> taskId = new java.util.concurrent.atomic.AtomicReference<>(queueTaskId); + ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); + // Default to blocking=false per A2A spec (return after task creation) boolean blocking = params.configuration() != null && Boolean.TRUE.equals(params.configuration().blocking()); + // Log blocking behavior from client request + if (params.configuration() != null && params.configuration().blocking() != null) { + LOGGER.debug("DefaultRequestHandler: Client requested blocking={} for task {}", + params.configuration().blocking(), taskId.get()); + } else if (params.configuration() != null) { + LOGGER.debug("DefaultRequestHandler: Client sent configuration but blocking=null, using default blocking={} for task {}", blocking, taskId.get()); + } else { + LOGGER.debug("DefaultRequestHandler: Client sent no configuration, using default blocking={} for task {}", blocking, taskId.get()); + } + LOGGER.debug("DefaultRequestHandler: Final blocking decision: {} for task {}", blocking, taskId.get()); + boolean interruptedOrNonBlocking = false; - EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId, mss.requestContext, queue); + EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(queueTaskId, mss.requestContext, queue); ResultAggregator.EventTypeAndInterrupt etai = null; EventKind kind = null; // Declare outside try block so it's in scope for return try { - // Create callback for push notifications during background event processing - Runnable pushNotificationCallback = () -> sendPushNotification(taskId, resultAggregator); - EventConsumer consumer = new EventConsumer(queue); // This callback must be added before we start consuming. Otherwise, @@ -424,7 +452,7 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); // Get agent future before consuming (for blocking calls to wait for agent completion) - CompletableFuture agentFuture = runningAgents.get(taskId); + CompletableFuture agentFuture = runningAgents.get(queueTaskId); etai = resultAggregator.consumeAndBreakOnInterrupt(consumer, blocking); if (etai == null) { @@ -432,7 +460,8 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte throw new InternalError("No result"); } interruptedOrNonBlocking = etai.interrupted(); - LOGGER.debug("Was interrupted or non-blocking: {}", interruptedOrNonBlocking); + LOGGER.debug("DefaultRequestHandler: interruptedOrNonBlocking={} (blocking={}, eventType={})", + interruptedOrNonBlocking, blocking, kind != null ? 
kind.getClass().getSimpleName() : null); // For blocking calls that were interrupted (returned on first event), // wait for agent execution and event processing BEFORE returning to client. @@ -441,30 +470,36 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte // during the consumption loop itself. kind = etai.eventType(); + // No ID switching needed - agent uses context.getTaskId() which is the same as queue key + // Store push notification config for newly created tasks (mirrors streaming logic) // Only for NEW tasks - existing tasks are handled by initMessageSend() if (mss.task() == null && kind instanceof Task createdTask && shouldAddPushInfo(params)) { - LOGGER.debug("Storing push notification config for new task {}", createdTask.id()); + LOGGER.debug("Storing push notification config for new task {} (original taskId from params: {})", + createdTask.id(), params.message().taskId()); pushConfigStore.setInfo(createdTask.id(), params.configuration().pushNotificationConfig()); } if (blocking && interruptedOrNonBlocking) { - // For blocking calls: ensure all events are processed before returning - // Order of operations is critical to avoid circular dependency: - // 1. Wait for agent to finish enqueueing events + // For blocking calls: ensure all consumed events are persisted to TaskStore before returning + // Order of operations is critical to avoid circular dependency and race conditions: + // 1. Wait for agent to finish enqueueing events (or timeout) // 2. Close the queue to signal consumption can complete // 3. Wait for consumption to finish processing events - // 4. Fetch final task state from TaskStore + // 4. (Implicit) MainEventBusProcessor persistence guarantee via consumption completion + // 5. Fetch current task state from TaskStore (includes all consumed & persisted events) + LOGGER.debug("DefaultRequestHandler: Entering blocking fire-and-forget handling for task {}", taskId.get()); try { // Step 1: Wait for agent to finish (with configurable timeout) if (agentFuture != null) { try { agentFuture.get(agentCompletionTimeoutSeconds, SECONDS); - LOGGER.debug("Agent completed for task {}", taskId); + LOGGER.debug("DefaultRequestHandler: Step 1 - Agent completed for task {}", taskId.get()); } catch (java.util.concurrent.TimeoutException e) { // Agent still running after timeout - that's fine, events already being processed - LOGGER.debug("Agent still running for task {} after {}s", taskId, agentCompletionTimeoutSeconds); + LOGGER.debug("DefaultRequestHandler: Step 1 - Agent still running for task {} after {}s timeout", + taskId.get(), agentCompletionTimeoutSeconds); } } @@ -472,55 +507,84 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte // For fire-and-forget tasks, there's no final event, so we need to close the queue // This allows EventConsumer.consumeAll() to exit queue.close(false, false); // graceful close, don't notify parent yet - LOGGER.debug("Closed queue for task {} to allow consumption completion", taskId); + LOGGER.debug("DefaultRequestHandler: Step 2 - Closed queue for task {} to allow consumption completion", taskId.get()); // Step 3: Wait for consumption to complete (now that queue is closed) if (etai.consumptionFuture() != null) { etai.consumptionFuture().get(consumptionCompletionTimeoutSeconds, SECONDS); - LOGGER.debug("Consumption completed for task {}", taskId); + LOGGER.debug("DefaultRequestHandler: Step 3 - Consumption completed for task {}", taskId.get()); } + + // Step 4: Implicit guarantee of 
persistence via consumption completion + // We do NOT add an explicit wait for MainEventBusProcessor here because: + // 1. MainEventBusProcessor persists BEFORE distributing to ChildQueues + // 2. Step 3 (consumption completion) already guarantees all consumed events are persisted + // 3. Adding another explicit synchronization point would require exposing + // MainEventBusProcessor internals and blocking event loop threads + // + // Note: For fire-and-forget tasks, if the agent is still running after Step 1 timeout, + // it may enqueue additional events. These will be persisted asynchronously but won't + // be included in the task state returned to the client (already consumed in Step 3). + } catch (InterruptedException e) { Thread.currentThread().interrupt(); - String msg = String.format("Error waiting for task %s completion", taskId); + String msg = String.format("Error waiting for task %s completion", taskId.get()); LOGGER.warn(msg, e); throw new InternalError(msg); } catch (java.util.concurrent.ExecutionException e) { - String msg = String.format("Error during task %s execution", taskId); + String msg = String.format("Error during task %s execution", taskId.get()); LOGGER.warn(msg, e.getCause()); throw new InternalError(msg); - } catch (java.util.concurrent.TimeoutException e) { - String msg = String.format("Timeout waiting for consumption to complete for task %s", taskId); - LOGGER.warn(msg, taskId); + } catch (TimeoutException e) { + // Timeout from consumption future.get() - different from finalization timeout + String msg = String.format("Timeout waiting for task %s consumption", taskId.get()); + LOGGER.warn(msg, e); throw new InternalError(msg); } - // Step 4: Fetch the final task state from TaskStore (all events have been processed) - // taskId is guaranteed non-null here (checked earlier) - String nonNullTaskId = taskId; + // Step 5: Fetch the current task state from TaskStore + // All events consumed in Step 3 are guaranteed persisted (MainEventBusProcessor + // ordering: persist → distribute → consume). This returns the persisted state + // including all consumed events and artifacts. + String nonNullTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); Task updatedTask = taskStore.get(nonNullTaskId); if (updatedTask != null) { kind = updatedTask; - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Fetched final task for {} with state {} and {} artifacts", - nonNullTaskId, updatedTask.status().state(), - updatedTask.artifacts().size()); - } + LOGGER.debug("DefaultRequestHandler: Step 5 - Fetched current task for {} with state {} and {} artifacts", + taskId.get(), updatedTask.status().state(), + updatedTask.artifacts().size()); + } else { + LOGGER.warn("DefaultRequestHandler: Step 5 - Task {} not found in TaskStore!", taskId.get()); } } - if (kind instanceof Task taskResult && !taskId.equals(taskResult.id())) { + String finalTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (kind instanceof Task taskResult && !finalTaskId.equals(taskResult.id())) { throw new InternalError("Task ID mismatch in agent response"); } - - // Send push notification after initial return (for both blocking and non-blocking) - pushNotificationCallback.run(); } finally { + // For non-blocking calls: close ChildQueue IMMEDIATELY to free EventConsumer thread + // CRITICAL: Must use immediate=true to clear the local queue, otherwise EventConsumer + // continues polling until queue drains naturally, holding executor thread. 
+ // Immediate close clears pending events and triggers EventQueueClosedException on next poll. + // Events continue flowing through MainQueue → MainEventBus → TaskStore. + if (!blocking && etai != null && etai.interrupted()) { + LOGGER.debug("DefaultRequestHandler: Non-blocking call in finally - closing ChildQueue IMMEDIATELY for task {} to free EventConsumer", taskId.get()); + queue.close(true); // immediate=true: clear queue and free EventConsumer + } + // Remove agent from map immediately to prevent accumulation - CompletableFuture agentFuture = runningAgents.remove(taskId); - LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId, runningAgents.size()); + CompletableFuture agentFuture = runningAgents.remove(queueTaskId); + String cleanupTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", cleanupTaskId, runningAgents.size()); - // Track cleanup as background task to avoid blocking Vert.x threads + // Cleanup as background task to avoid blocking Vert.x threads // Pass the consumption future to ensure cleanup waits for background consumption to complete - trackBackgroundTask(cleanupProducer(agentFuture, etai != null ? etai.consumptionFuture() : null, taskId, queue, false)); + cleanupProducer(agentFuture, etai != null ? etai.consumptionFuture() : null, cleanupTaskId, queue, false) + .whenComplete((res, err) -> { + if (err != null) { + LOGGER.error("Error during async cleanup for task {}", taskId.get(), err); + } + }); } LOGGER.debug("Returning: {}", kind); @@ -530,29 +594,44 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte @Override public Flow.Publisher onMessageSendStream( MessageSendParams params, ServerCallContext context) throws A2AError { - LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}; backgroundTasks: {}", - params.message().taskId(), params.message().contextId(), runningAgents.size(), backgroundTasks.size()); + LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}", + params.message().taskId(), params.message().contextId(), runningAgents.size()); + + // Build MessageSendSetup which creates RequestContext with real taskId (auto-generated if needed) MessageSendSetup mss = initMessageSend(params, context); - @Nullable String initialTaskId = mss.requestContext.getTaskId(); - // For streaming, taskId can be null initially (will be set when Task event arrives) - // Use a temporary ID for queue creation if needed - String queueTaskId = initialTaskId != null ? initialTaskId : "temp-" + java.util.UUID.randomUUID(); + // Use the taskId from RequestContext for queue management (no temp ID needed!) 
+ // RequestContext.build() guarantees taskId is non-null via checkOrGenerateTaskId() + String queueTaskId = java.util.Objects.requireNonNull( + mss.requestContext.getTaskId(), "TaskId must be non-null after RequestContext.build()"); + final AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); - AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); - @SuppressWarnings("NullAway") - EventQueue queue = queueManager.createOrTap(taskId.get()); + // Create queue with real taskId (no tempId parameter needed) + EventQueue queue = queueManager.createOrTap(queueTaskId); LOGGER.debug("Created/tapped queue for task {}: {}", taskId.get(), queue); - ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor); + + // Store push notification config SYNCHRONOUSLY for new tasks before agent starts + // This ensures config is available when MainEventBusProcessor sends push notifications + // For existing tasks, config is stored in initMessageSend() + if (mss.task() == null && shouldAddPushInfo(params)) { + // Satisfy Nullaway + Objects.requireNonNull(taskId.get(), "taskId was null"); + LOGGER.debug("Storing push notification config for new streaming task {} EARLY (original taskId from params: {})", + taskId.get(), params.message().taskId()); + pushConfigStore.setInfo(taskId.get(), params.configuration().pushNotificationConfig()); + } + + ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(queueTaskId, mss.requestContext, queue); // Move consumer creation and callback registration outside try block - // so consumer is available for background consumption on client disconnect EventConsumer consumer = new EventConsumer(queue); producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); - AtomicBoolean backgroundConsumeStarted = new AtomicBoolean(false); + // Store cancel callback in context for closeHandler to access + // When client disconnects, closeHandler can call this to stop EventConsumer polling loop + context.setEventConsumerCancelCallback(consumer::cancel); try { Flow.Publisher results = resultAggregator.consumeAndEmit(consumer); @@ -562,36 +641,13 @@ public Flow.Publisher onMessageSendStream( processor(createTubeConfig(), results, ((errorConsumer, item) -> { Event event = item.getEvent(); if (event instanceof Task createdTask) { - if (!Objects.equals(taskId.get(), createdTask.id())) { - errorConsumer.accept(new InternalError("Task ID mismatch in agent response")); - } - - // TODO the Python implementation no longer has the following block but removing it causes - // failures here - try { - queueManager.add(createdTask.id(), queue); - taskId.set(createdTask.id()); - } catch (TaskQueueExistsException e) { - // TODO Log - } - if (pushConfigStore != null && - params.configuration() != null && - params.configuration().pushNotificationConfig() != null) { - - pushConfigStore.setInfo( - createdTask.id(), - params.configuration().pushNotificationConfig()); - } - - } - String currentTaskId = taskId.get(); - if (pushSender != null && currentTaskId != null) { - EventKind latest = resultAggregator.getCurrentResult(); - if (latest instanceof Task latestTask) { - pushSender.sendNotification(latestTask); + // Verify task ID matches (should always match now - agent uses context.getTaskId()) + String currentId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if 
(!currentId.equals(createdTask.id())) { + errorConsumer.accept(new InternalError("Task ID mismatch: expected " + currentId + + " but got " + createdTask.id())); } } - return true; })); @@ -600,7 +656,8 @@ public Flow.Publisher onMessageSendStream( Flow.Publisher finalPublisher = convertingProcessor(eventPublisher, event -> (StreamingEventKind) event); - // Wrap publisher to detect client disconnect and continue background consumption + // Wrap publisher to detect client disconnect and immediately close ChildQueue + // This prevents ChildQueue backpressure from blocking MainEventBusProcessor return subscriber -> { String currentTaskId = taskId.get(); LOGGER.debug("Creating subscription wrapper for task {}", currentTaskId); @@ -621,8 +678,10 @@ public void request(long n) { @Override public void cancel() { - LOGGER.debug("Client cancelled subscription for task {}, starting background consumption", taskId.get()); - startBackgroundConsumption(); + LOGGER.debug("Client cancelled subscription for task {}, closing ChildQueue immediately", taskId.get()); + // Close ChildQueue immediately to prevent backpressure + // (clears queue and releases semaphore permits) + queue.close(true); // immediate=true subscription.cancel(); } }); @@ -647,8 +706,8 @@ public void onComplete() { subscriber.onComplete(); } catch (IllegalStateException e) { // Client already disconnected and response closed - this is expected - // for streaming responses where client disconnect triggers background - // consumption. Log and ignore. + // for streaming responses where client disconnect closes ChildQueue. + // Log and ignore. if (e.getMessage() != null && e.getMessage().contains("Response has already been written")) { LOGGER.debug("Client disconnected before onComplete, response already closed for task {}", taskId.get()); } else { @@ -656,36 +715,26 @@ public void onComplete() { } } } - - private void startBackgroundConsumption() { - if (backgroundConsumeStarted.compareAndSet(false, true)) { - LOGGER.debug("Starting background consumption for task {}", taskId.get()); - // Client disconnected: continue consuming and persisting events in background - CompletableFuture bgTask = CompletableFuture.runAsync(() -> { - try { - LOGGER.debug("Background consumption thread started for task {}", taskId.get()); - resultAggregator.consumeAll(consumer); - LOGGER.debug("Background consumption completed for task {}", taskId.get()); - } catch (Exception e) { - LOGGER.error("Error during background consumption for task {}", taskId.get(), e); - } - }, executor); - trackBackgroundTask(bgTask); - } else { - LOGGER.debug("Background consumption already started for task {}", taskId.get()); - } - } }); }; } finally { - LOGGER.debug("onMessageSendStream FINALLY - task: {}; runningAgents: {}; backgroundTasks: {}", - taskId.get(), runningAgents.size(), backgroundTasks.size()); - - // Remove agent from map immediately to prevent accumulation - CompletableFuture agentFuture = runningAgents.remove(taskId.get()); - LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId.get(), runningAgents.size()); - - trackBackgroundTask(cleanupProducer(agentFuture, null, Objects.requireNonNull(taskId.get()), queue, true)); + // Needed to satisfy Nullaway + String idOfTask = taskId.get(); + if (idOfTask != null) { + LOGGER.debug("onMessageSendStream FINALLY - task: {}; runningAgents: {}", + idOfTask, runningAgents.size()); + + // Remove agent from map immediately to prevent accumulation + CompletableFuture agentFuture = 
runningAgents.remove(idOfTask); + LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId.get(), runningAgents.size()); + + cleanupProducer(agentFuture, null, idOfTask, queue, true) + .whenComplete((res, err) -> { + if (err != null) { + LOGGER.error("Error during async cleanup for streaming task {}", taskId.get(), err); + } + }); + } } } @@ -746,7 +795,7 @@ public Flow.Publisher onResubscribeToTask( } TaskManager taskManager = new TaskManager(task.id(), task.contextId(), taskStore, null); - ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); EventQueue queue = queueManager.tap(task.id()); LOGGER.debug("onResubscribeToTask - tapped queue: {}", queue != null ? System.identityHashCode(queue) : "null"); @@ -761,6 +810,13 @@ public Flow.Publisher onResubscribeToTask( queue = queueManager.createOrTap(task.id()); } + // Per A2A Protocol Spec 3.1.6 (Subscribe to Task): + // "The operation MUST return a Task object as the first event in the stream, + // representing the current state of the task at the time of subscription." + // Enqueue the current task state directly to this ChildQueue only (already persisted, no need for MainEventBus) + queue.enqueueEventLocalOnly(task); + LOGGER.debug("onResubscribeToTask - enqueued current task state as first event for taskId: {}", params.id()); + EventConsumer consumer = new EventConsumer(queue); Flow.Publisher results = resultAggregator.consumeAndEmit(consumer); LOGGER.debug("onResubscribeToTask - returning publisher for taskId: {}", params.id()); @@ -819,8 +875,7 @@ public void run() { LOGGER.debug("Agent execution starting for task {}", taskId); agentExecutor.execute(requestContext, queue); LOGGER.debug("Agent execution completed for task {}", taskId); - // No longer wait for queue poller to start - the consumer (which is guaranteed - // to be running on the Vert.x worker thread) will handle queue lifecycle. + // The consumer (running on the Vert.x worker thread) handles queue lifecycle. // This avoids blocking agent-executor threads waiting for worker threads. } }; @@ -833,8 +888,8 @@ public void run() { // Don't close queue here - let the consumer handle it via error callback // This ensures the consumer (which may not have started polling yet) gets the error } - // Queue lifecycle is now managed entirely by EventConsumer.consumeAll() - // which closes the queue on final events. No need to close here. + // Queue lifecycle is managed by EventConsumer.consumeAll() + // which closes the queue on final events. 
logThreadStats("AGENT COMPLETE END"); runnable.invokeDoneCallbacks(); }); @@ -843,47 +898,6 @@ public void run() { return runnable; } - private void trackBackgroundTask(CompletableFuture task) { - backgroundTasks.add(task); - LOGGER.debug("Tracking background task (total: {}): {}", backgroundTasks.size(), task); - - task.whenComplete((result, throwable) -> { - try { - if (throwable != null) { - // Unwrap CompletionException to check for CancellationException - Throwable cause = throwable; - if (throwable instanceof java.util.concurrent.CompletionException && throwable.getCause() != null) { - cause = throwable.getCause(); - } - - if (cause instanceof java.util.concurrent.CancellationException) { - LOGGER.debug("Background task cancelled: {}", task); - } else { - LOGGER.error("Background task failed", throwable); - } - } - } finally { - backgroundTasks.remove(task); - LOGGER.debug("Removed background task (remaining: {}): {}", backgroundTasks.size(), task); - } - }); - } - - /** - * Wait for all background tasks to complete. - * Useful for testing to ensure cleanup completes before assertions. - * - * @return CompletableFuture that completes when all background tasks finish - */ - public CompletableFuture waitForBackgroundTasks() { - CompletableFuture[] tasks = backgroundTasks.toArray(new CompletableFuture[0]); - if (tasks.length == 0) { - return CompletableFuture.completedFuture(null); - } - LOGGER.debug("Waiting for {} background tasks to complete", tasks.length); - return CompletableFuture.allOf(tasks); - } - private CompletableFuture cleanupProducer(@Nullable CompletableFuture agentFuture, @Nullable CompletableFuture consumptionFuture, String taskId, EventQueue queue, boolean isStreaming) { LOGGER.debug("Starting cleanup for task {} (streaming={})", taskId, isStreaming); logThreadStats("CLEANUP START"); @@ -908,14 +922,20 @@ private CompletableFuture cleanupProducer(@Nullable CompletableFuture cleanupProducer(@Nullable CompletableFuture + * Enable independently with: {@code logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG} + *

    */ @SuppressWarnings("unused") // Used for debugging private void logThreadStats(String label) { // Early return if debug logging is not enabled to avoid overhead - if (!LOGGER.isDebugEnabled()) { + if (!THREAD_STATS_LOGGER.isDebugEnabled()) { return; } @@ -982,28 +1036,57 @@ private void logThreadStats(String label) { } int activeThreads = rootGroup.activeCount(); - LOGGER.debug("=== THREAD STATS: {} ===", label); - LOGGER.debug("Active threads: {}", activeThreads); - LOGGER.debug("Running agents: {}", runningAgents.size()); - LOGGER.debug("Background tasks: {}", backgroundTasks.size()); - LOGGER.debug("Queue manager active queues: {}", queueManager.getClass().getSimpleName()); + // Count specific thread types + Thread[] threads = new Thread[activeThreads * 2]; + int count = rootGroup.enumerate(threads); + int eventConsumerThreads = 0; + int agentExecutorThreads = 0; + for (int i = 0; i < count; i++) { + if (threads[i] != null) { + String name = threads[i].getName(); + if (name.startsWith("a2a-event-consumer-")) { + eventConsumerThreads++; + } else if (name.startsWith("a2a-agent-executor-")) { + agentExecutorThreads++; + } + } + } + + THREAD_STATS_LOGGER.debug("=== THREAD STATS: {} ===", label); + THREAD_STATS_LOGGER.debug("Total active threads: {}", activeThreads); + THREAD_STATS_LOGGER.debug("EventConsumer threads: {}", eventConsumerThreads); + THREAD_STATS_LOGGER.debug("AgentExecutor threads: {}", agentExecutorThreads); + THREAD_STATS_LOGGER.debug("Running agents: {}", runningAgents.size()); + THREAD_STATS_LOGGER.debug("Queue manager active queues: {}", queueManager.getClass().getSimpleName()); // List running agents if (!runningAgents.isEmpty()) { - LOGGER.debug("Running agent tasks:"); + THREAD_STATS_LOGGER.debug("Running agent tasks:"); runningAgents.forEach((taskId, future) -> - LOGGER.debug(" - Task {}: {}", taskId, future.isDone() ? "DONE" : "RUNNING") + THREAD_STATS_LOGGER.debug(" - Task {}: {}", taskId, future.isDone() ? "DONE" : "RUNNING") ); } - // List background tasks - if (!backgroundTasks.isEmpty()) { - LOGGER.debug("Background tasks:"); - backgroundTasks.forEach(task -> - LOGGER.debug(" - {}: {}", task, task.isDone() ? "DONE" : "RUNNING") - ); + THREAD_STATS_LOGGER.debug("=== END THREAD STATS ==="); + } + + /** + * Check if an event represents a final task state. 
+ * + * @param eventKind the event to check + * @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN) + */ + private boolean isFinalEvent(EventKind eventKind) { + if (!(eventKind instanceof Event event)) { + return false; + } + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); } - LOGGER.debug("=== END THREAD STATS ==="); + return false; } private record MessageSendSetup(TaskManager taskManager, @Nullable Task task, RequestContext requestContext) {} diff --git a/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java b/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java index 1e3ae1206..15f94d7e8 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java @@ -32,8 +32,9 @@ public class InMemoryTaskStore implements TaskStore, TaskStateProvider { private final ConcurrentMap tasks = new ConcurrentHashMap<>(); @Override - public void save(Task task) { + public void save(Task task, boolean isReplicated) { tasks.put(task.id(), task); + // InMemoryTaskStore doesn't fire TaskFinalizedEvent, so isReplicated is unused here } @Override diff --git a/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java b/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java index 95684e199..506b3f3b6 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java +++ b/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java @@ -14,9 +14,9 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueueItem; import io.a2a.spec.A2AError; -import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; import io.a2a.spec.EventKind; +import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskState; @@ -31,12 +31,14 @@ public class ResultAggregator { private final TaskManager taskManager; private final Executor executor; + private final Executor eventConsumerExecutor; private volatile @Nullable Message message; - public ResultAggregator(TaskManager taskManager, @Nullable Message message, Executor executor) { + public ResultAggregator(TaskManager taskManager, @Nullable Message message, Executor executor, Executor eventConsumerExecutor) { this.taskManager = taskManager; this.message = message; this.executor = executor; + this.eventConsumerExecutor = eventConsumerExecutor; } public @Nullable EventKind getCurrentResult() { @@ -49,20 +51,23 @@ public ResultAggregator(TaskManager taskManager, @Nullable Message message, Exec public Flow.Publisher consumeAndEmit(EventConsumer consumer) { Flow.Publisher allItems = consumer.consumeAll(); - // Process items conditionally - only save non-replicated events to database - return processor(createTubeConfig(), allItems, (errorConsumer, item) -> { - // Only process non-replicated events to avoid duplicate database writes - if (!item.isReplicated()) { - try { - callTaskManagerProcess(item.getEvent()); - } catch (A2AServerException e) { - errorConsumer.accept(e); - return false; - } - } - // Continue processing and emit (both replicated and non-replicated) + // Just stream events - no persistence needed + // TaskStore update moved to MainEventBusProcessor + Flow.Publisher processed = 
processor(createTubeConfig(), allItems, (errorConsumer, item) -> { + // Continue processing and emit all events return true; }); + + // Wrap the publisher to ensure subscription happens on eventConsumerExecutor + // This prevents EventConsumer polling loop from running on AgentExecutor threads + // which caused thread accumulation when those threads didn't timeout + return new Flow.Publisher() { + @Override + public void subscribe(Flow.Subscriber subscriber) { + // Submit subscription to eventConsumerExecutor to isolate polling work + eventConsumerExecutor.execute(() -> processed.subscribe(subscriber)); + } + }; } public EventKind consumeAll(EventConsumer consumer) throws A2AError { @@ -81,15 +86,7 @@ public EventKind consumeAll(EventConsumer consumer) throws A2AError { return false; } } - // Only process non-replicated events to avoid duplicate database writes - if (!item.isReplicated()) { - try { - callTaskManagerProcess(event); - } catch (A2AServerException e) { - error.set(e); - return false; - } - } + // TaskStore update moved to MainEventBusProcessor return true; }, error::set); @@ -113,18 +110,24 @@ public EventKind consumeAll(EventConsumer consumer) throws A2AError { public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, boolean blocking) throws A2AError { Flow.Publisher allItems = consumer.consumeAll(); AtomicReference message = new AtomicReference<>(); + AtomicReference capturedTask = new AtomicReference<>(); // Capture Task events AtomicBoolean interrupted = new AtomicBoolean(false); AtomicReference errorRef = new AtomicReference<>(); CompletableFuture completionFuture = new CompletableFuture<>(); // Separate future for tracking background consumption completion CompletableFuture consumptionCompletionFuture = new CompletableFuture<>(); + // Latch to ensure EventConsumer starts polling before we wait on completionFuture + java.util.concurrent.CountDownLatch pollingStarted = new java.util.concurrent.CountDownLatch(1); // CRITICAL: The subscription itself must run on a background thread to avoid blocking // the Vert.x worker thread. EventConsumer.consumeAll() starts a polling loop that // blocks in dequeueEventItem(), so we must subscribe from a background thread. - // Use the @Internal executor (not ForkJoinPool.commonPool) to avoid saturation - // during concurrent request bursts. + // Use the dedicated @EventConsumerExecutor (cached thread pool) which creates threads + // on demand for I/O-bound polling. Using the @Internal executor caused deadlock when + // pool exhausted (100+ concurrent queues but maxPoolSize=50). CompletableFuture.runAsync(() -> { + // Signal that polling is about to start + pollingStarted.countDown(); consumer( createTubeConfig(), allItems, @@ -146,25 +149,30 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, return false; } - // Process event through TaskManager - only for non-replicated events - if (!item.isReplicated()) { - try { - callTaskManagerProcess(event); - } catch (A2AServerException e) { - errorRef.set(e); - completionFuture.completeExceptionally(e); - return false; + // Capture Task events (especially for new tasks where taskManager.getTask() would return null) + // We capture the LATEST task to ensure we get the most up-to-date state + if (event instanceof Task t) { + Task previousTask = capturedTask.get(); + capturedTask.set(t); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Captured Task event: id={}, state={} (previous: {})", + t.id(), t.status().state(), + previousTask != null ? 
previousTask.id() + "/" + previousTask.status().state() : "none"); } } + // TaskStore update moved to MainEventBusProcessor + // Determine interrupt behavior boolean shouldInterrupt = false; - boolean continueInBackground = false; boolean isFinalEvent = (event instanceof Task task && task.status().state().isFinal()) || (event instanceof TaskStatusUpdateEvent tsue && tsue.isFinal()); boolean isAuthRequired = (event instanceof Task task && task.status().state() == TaskState.AUTH_REQUIRED) || (event instanceof TaskStatusUpdateEvent tsue && tsue.status().state() == TaskState.AUTH_REQUIRED); + LOGGER.debug("ResultAggregator: Evaluating interrupt (blocking={}, isFinal={}, isAuth={}, eventType={})", + blocking, isFinalEvent, isAuthRequired, event.getClass().getSimpleName()); + // Always interrupt on auth_required, as it needs external action. if (isAuthRequired) { // auth-required is a special state: the message should be @@ -174,20 +182,19 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, // new request is expected in order for the agent to make progress, // so the agent should exit. shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (AUTH_REQUIRED)"); } else if (!blocking) { // For non-blocking calls, interrupt as soon as a task is available. shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (non-blocking)"); } else if (blocking) { // For blocking calls: Interrupt to free Vert.x thread, but continue in background // Python's async consumption doesn't block threads, but Java's does // So we interrupt to return quickly, then rely on background consumption - // DefaultRequestHandler will fetch the final state from TaskStore shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (blocking, isFinal={})", isFinalEvent); if (LOGGER.isDebugEnabled()) { LOGGER.debug("Blocking call for task {}: {} event, returning with background consumption", taskIdForLogging(), isFinalEvent ? "final" : "non-final"); @@ -195,14 +202,14 @@ else if (blocking) { } if (shouldInterrupt) { + LOGGER.debug("ResultAggregator: Interrupting consumption (setting interrupted=true)"); // Complete the future to unblock the main thread interrupted.set(true); completionFuture.complete(null); // For blocking calls, DON'T complete consumptionCompletionFuture here. // Let it complete naturally when subscription finishes (onComplete callback below). - // This ensures all events are processed and persisted to TaskStore before - // DefaultRequestHandler.cleanupProducer() proceeds with cleanup. + // This ensures all events are fully processed before cleanup. // // For non-blocking and auth-required calls, complete immediately to allow // cleanup to proceed while consumption continues in background. 
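// --- Illustrative sketch (not part of the patch): the "start the consumer's polling loop on a
// dedicated executor, then wait for it to begin before producing" pattern described in the
// comments above. All names here (PollingSketch, events, "FINAL") are hypothetical and only
// demonstrate the idea of pollingStarted.await(...) guarding against a producer/consumer race.
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

class PollingSketch {
    public static void main(String[] args) throws InterruptedException {
        BlockingQueue<String> events = new LinkedBlockingQueue<>();
        // Dedicated pool for the I/O-bound polling loop, analogous to the @EventConsumerExecutor idea
        ExecutorService consumerExecutor = Executors.newCachedThreadPool();
        CountDownLatch pollingStarted = new CountDownLatch(1);

        consumerExecutor.execute(() -> {
            pollingStarted.countDown(); // signal: the polling loop is about to start
            try {
                String event;
                while (!"FINAL".equals(event = events.take())) { // blocking poll, mostly idle
                    System.out.println("consumed " + event);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });

        // Producer side: wait (bounded) until the consumer is polling before enqueueing events.
        if (!pollingStarted.await(5, TimeUnit.SECONDS)) {
            throw new IllegalStateException("consumer did not start polling in time");
        }
        events.put("status-update");
        events.put("FINAL");
        consumerExecutor.shutdown();
    }
}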
@@ -237,7 +244,16 @@ else if (blocking) { } } ); - }, executor); + }, eventConsumerExecutor); + + // Wait for EventConsumer to start polling before we wait for events + // This prevents race where agent enqueues events before EventConsumer starts + try { + pollingStarted.await(5, java.util.concurrent.TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new io.a2a.spec.InternalError("Interrupted waiting for EventConsumer to start"); + } // Wait for completion or interruption try { @@ -261,28 +277,30 @@ else if (blocking) { Utils.rethrow(error); } - EventKind eventType; - Message msg = message.get(); - if (msg != null) { - eventType = msg; - } else { - Task task = taskManager.getTask(); - if (task == null) { - throw new io.a2a.spec.InternalError("No task or message available after consuming events"); + // Return Message if captured, otherwise Task if captured, otherwise fetch from TaskStore + EventKind eventKind = message.get(); + if (eventKind == null) { + eventKind = capturedTask.get(); + if (LOGGER.isDebugEnabled() && eventKind instanceof Task t) { + LOGGER.debug("Returning capturedTask: id={}, state={}", t.id(), t.status().state()); } - eventType = task; + } + if (eventKind == null) { + eventKind = taskManager.getTask(); + if (LOGGER.isDebugEnabled() && eventKind instanceof Task t) { + LOGGER.debug("Returning task from TaskStore: id={}, state={}", t.id(), t.status().state()); + } + } + if (eventKind == null) { + throw new InternalError("Could not find a Task/Message for " + taskManager.getTaskId()); } return new EventTypeAndInterrupt( - eventType, + eventKind, interrupted.get(), consumptionCompletionFuture); } - private void callTaskManagerProcess(Event event) throws A2AServerException { - taskManager.process(event); - } - private String taskIdForLogging() { Task task = taskManager.getTask(); return task != null ? 
task.id() : "unknown"; diff --git a/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java b/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java index fd3696a60..948ec596c 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java +++ b/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java @@ -12,7 +12,7 @@ import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; -import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskArtifactUpdateEvent; @@ -59,12 +59,13 @@ public TaskManager(@Nullable String taskId, @Nullable String contextId, TaskStor return currentTask; } - Task saveTaskEvent(Task task) throws A2AServerException { + boolean saveTaskEvent(Task task, boolean isReplicated) throws A2AServerException { checkIdsAndUpdateIfNecessary(task.id(), task.contextId()); - return saveTask(task); + Task savedTask = saveTask(task, isReplicated); + return savedTask.status() != null && savedTask.status().state() != null && savedTask.status().state().isFinal(); } - Task saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException { + boolean saveTaskEvent(TaskStatusUpdateEvent event, boolean isReplicated) throws A2AServerException { checkIdsAndUpdateIfNecessary(event.taskId(), event.contextId()); Task task = ensureTask(event.taskId(), event.contextId()); @@ -86,10 +87,11 @@ Task saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException { } task = builder.build(); - return saveTask(task); + Task savedTask = saveTask(task, isReplicated); + return savedTask.status() != null && savedTask.status().state() != null && savedTask.status().state().isFinal(); } - Task saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException { + boolean saveTaskEvent(TaskArtifactUpdateEvent event, boolean isReplicated) throws A2AServerException { checkIdsAndUpdateIfNecessary(event.taskId(), event.contextId()); Task task = ensureTask(event.taskId(), event.contextId()); // taskId is guaranteed to be non-null after checkIdsAndUpdateIfNecessary @@ -98,18 +100,20 @@ Task saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException { throw new IllegalStateException("taskId should not be null after checkIdsAndUpdateIfNecessary"); } task = appendArtifactToTask(task, event, nonNullTaskId); - return saveTask(task); + Task savedTask = saveTask(task, isReplicated); + return savedTask.status() != null && savedTask.status().state() != null && savedTask.status().state().isFinal(); } - public Event process(Event event) throws A2AServerException { + public boolean process(Event event, boolean isReplicated) throws A2AServerException { + boolean isFinal = false; if (event instanceof Task task) { - saveTaskEvent(task); + isFinal = saveTaskEvent(task, isReplicated); } else if (event instanceof TaskStatusUpdateEvent taskStatusUpdateEvent) { - saveTaskEvent(taskStatusUpdateEvent); + isFinal = saveTaskEvent(taskStatusUpdateEvent, isReplicated); } else if (event instanceof TaskArtifactUpdateEvent taskArtifactUpdateEvent) { - saveTaskEvent(taskArtifactUpdateEvent); + isFinal = saveTaskEvent(taskArtifactUpdateEvent, isReplicated); } - return event; + return isFinal; } public Task updateWithMessage(Message message, Task task) { @@ -125,7 +129,7 @@ public Task updateWithMessage(Message message, Task task) { .status(status) .history(history) .build(); - saveTask(task); + saveTask(task, false); // Local operation, not replicated return task; } @@ -133,7 +137,7 @@ 
private void checkIdsAndUpdateIfNecessary(String eventTaskId, String eventContex if (taskId != null && !eventTaskId.equals(taskId)) { throw new A2AServerException( "Invalid task id", - new InvalidParamsError(String.format("Task in event doesn't match TaskManager "))); + new InternalError(String.format("Task event has taskId %s but TaskManager has %s", eventTaskId, taskId))); } if (taskId == null) { taskId = eventTaskId; @@ -155,7 +159,7 @@ private Task ensureTask(String eventTaskId, String eventContextId) { } if (task == null) { task = createTask(eventTaskId, eventContextId); - saveTask(task); + saveTask(task, false); // Local operation, not replicated } return task; } @@ -170,8 +174,8 @@ private Task createTask(String taskId, String contextId) { .build(); } - private Task saveTask(Task task) { - taskStore.save(task); + private Task saveTask(Task task, boolean isReplicated) { + taskStore.save(task, isReplicated); if (taskId == null) { taskId = task.id(); contextId = task.contextId(); diff --git a/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java b/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java index 18707fba2..3df903f77 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java @@ -100,8 +100,11 @@ public interface TaskStore { * Saves or updates a task. * * @param task the task to save + * @param isReplicated true if this task update came from a replicated event, + * false if it originated locally. Used to prevent feedback loops + * in replicated scenarios (e.g., don't fire TaskFinalizedEvent for replicated updates) */ - void save(Task task); + void save(Task task, boolean isReplicated); /** * Retrieves a task by its ID. diff --git a/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java b/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java index e26dd55fb..eee254ba3 100644 --- a/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java +++ b/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java @@ -1,8 +1,8 @@ package io.a2a.server.util.async; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -26,6 +26,7 @@ public class AsyncExecutorProducer { private static final String A2A_EXECUTOR_CORE_POOL_SIZE = "a2a.executor.core-pool-size"; private static final String A2A_EXECUTOR_MAX_POOL_SIZE = "a2a.executor.max-pool-size"; private static final String A2A_EXECUTOR_KEEP_ALIVE_SECONDS = "a2a.executor.keep-alive-seconds"; + private static final String A2A_EXECUTOR_QUEUE_CAPACITY = "a2a.executor.queue-capacity"; @Inject A2AConfigProvider configProvider; @@ -57,6 +58,16 @@ public class AsyncExecutorProducer { */ long keepAliveSeconds; + /** + * Queue capacity for pending tasks. + *

    + * Property: {@code a2a.executor.queue-capacity}
    + * Default: 100
    + * Note: Must be bounded to allow pool growth to maxPoolSize. + * When queue is full, new threads are created up to maxPoolSize. + */ + int queueCapacity; + private @Nullable ExecutorService executor; @PostConstruct @@ -64,18 +75,34 @@ public void init() { corePoolSize = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_CORE_POOL_SIZE)); maxPoolSize = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_MAX_POOL_SIZE)); keepAliveSeconds = Long.parseLong(configProvider.getValue(A2A_EXECUTOR_KEEP_ALIVE_SECONDS)); - - LOGGER.info("Initializing async executor: corePoolSize={}, maxPoolSize={}, keepAliveSeconds={}", - corePoolSize, maxPoolSize, keepAliveSeconds); - - executor = new ThreadPoolExecutor( + queueCapacity = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_QUEUE_CAPACITY)); + + LOGGER.info("Initializing async executor: corePoolSize={}, maxPoolSize={}, keepAliveSeconds={}, queueCapacity={}", + corePoolSize, maxPoolSize, keepAliveSeconds, queueCapacity); + + // CRITICAL: Use ArrayBlockingQueue (bounded) instead of LinkedBlockingQueue (unbounded). + // With unbounded queue, ThreadPoolExecutor NEVER grows beyond corePoolSize because the + // queue never fills. This causes executor pool exhaustion during concurrent requests when + // EventConsumer polling threads hold all core threads and agent tasks queue indefinitely. + // Bounded queue enables pool growth: when queue is full, new threads are created up to + // maxPoolSize, preventing agent execution starvation. + ThreadPoolExecutor tpe = new ThreadPoolExecutor( corePoolSize, maxPoolSize, keepAliveSeconds, TimeUnit.SECONDS, - new LinkedBlockingQueue<>(), + new ArrayBlockingQueue<>(queueCapacity), new A2AThreadFactory() ); + + // CRITICAL: Allow core threads to timeout after keepAliveSeconds when idle. + // By default, ThreadPoolExecutor only times out threads above corePoolSize. + // Without this, core threads accumulate during testing and never clean up. + // This is essential for streaming scenarios where many short-lived tasks create threads + // for agent execution and cleanup callbacks, but those threads remain idle afterward. + tpe.allowCoreThreadTimeOut(true); + + executor = tpe; } @PreDestroy @@ -106,6 +133,22 @@ public Executor produce() { return executor; } + /** + * Log current executor pool statistics for diagnostics. + * Useful for debugging pool exhaustion or sizing issues. 
+ */ + public void logPoolStats() { + if (executor instanceof ThreadPoolExecutor tpe) { + LOGGER.info("Executor pool stats: active={}/{}, queued={}/{}, completed={}, total={}", + tpe.getActiveCount(), + tpe.getPoolSize(), + tpe.getQueue().size(), + queueCapacity, + tpe.getCompletedTaskCount(), + tpe.getTaskCount()); + } + } + private static class A2AThreadFactory implements ThreadFactory { private final AtomicInteger threadNumber = new AtomicInteger(1); private final String namePrefix = "a2a-agent-executor-"; diff --git a/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java b/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java new file mode 100644 index 000000000..24ff7f5d1 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java @@ -0,0 +1,93 @@ +package io.a2a.server.util.async; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; +import jakarta.inject.Qualifier; + +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.lang.annotation.ElementType.*; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * Produces a dedicated executor for EventConsumer polling threads. + *

+ * CRITICAL: EventConsumer polling must use a separate executor from AgentExecutor because:
+ * <ul>
+ *   <li>EventConsumer threads are I/O-bound (blocking on queue.poll()), not CPU-bound</li>
+ *   <li>One EventConsumer thread is needed per active queue (can be 100+ concurrent)</li>
+ *   <li>Threads are mostly idle, waiting for events</li>
+ *   <li>Using the same bounded pool as AgentExecutor causes deadlock when the pool is exhausted</li>
+ * </ul>
+ * <p>
+ * Uses a cached thread pool (unbounded) with automatic thread reclamation:
+ * <ul>
+ *   <li>Creates threads on demand as EventConsumers start</li>
+ *   <li>Idle threads are automatically terminated after 10 seconds</li>
+ *   <li>No queue saturation since threads are created as needed</li>
+ * </ul>
    + */ +@ApplicationScoped +public class EventConsumerExecutorProducer { + private static final Logger LOGGER = LoggerFactory.getLogger(EventConsumerExecutorProducer.class); + + /** + * Qualifier annotation for EventConsumer executor injection. + */ + @Retention(RUNTIME) + @Target({METHOD, FIELD, PARAMETER, TYPE}) + @Qualifier + public @interface EventConsumerExecutor { + } + + /** + * Thread factory for EventConsumer threads. + */ + private static class EventConsumerThreadFactory implements ThreadFactory { + private final AtomicInteger threadNumber = new AtomicInteger(1); + + @Override + public Thread newThread(Runnable r) { + Thread thread = new Thread(r, "a2a-event-consumer-" + threadNumber.getAndIncrement()); + thread.setDaemon(true); + return thread; + } + } + + private @Nullable ExecutorService executor; + + @Produces + @EventConsumerExecutor + @ApplicationScoped + public Executor eventConsumerExecutor() { + // Cached thread pool with 10s idle timeout (reduced from default 60s): + // - Creates threads on demand as EventConsumers start + // - Reclaims idle threads after 10s to prevent accumulation during fast test execution + // - Perfect for I/O-bound EventConsumer polling which blocks on queue.poll() + // - 10s timeout balances thread reuse (production) vs cleanup (testing) + executor = new ThreadPoolExecutor( + 0, // corePoolSize - no core threads + Integer.MAX_VALUE, // maxPoolSize - unbounded + 10, TimeUnit.SECONDS, // keepAliveTime - 10s idle timeout + new SynchronousQueue<>(), // queue - same as cached pool + new EventConsumerThreadFactory() + ); + + LOGGER.info("Initialized EventConsumer executor: cached thread pool (unbounded, 10s idle timeout)"); + + return executor; + } +} diff --git a/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java b/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java new file mode 100644 index 000000000..737fbac23 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java @@ -0,0 +1,136 @@ +package io.a2a.server.util.sse; + +import io.a2a.grpc.utils.JSONRPCUtils; +import io.a2a.jsonrpc.common.wrappers.A2AErrorResponse; +import io.a2a.jsonrpc.common.wrappers.A2AResponse; +import io.a2a.jsonrpc.common.wrappers.CancelTaskResponse; +import io.a2a.jsonrpc.common.wrappers.DeleteTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.GetExtendedAgentCardResponse; +import io.a2a.jsonrpc.common.wrappers.GetTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.GetTaskResponse; +import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; +import io.a2a.jsonrpc.common.wrappers.SendMessageResponse; +import io.a2a.jsonrpc.common.wrappers.SendStreamingMessageResponse; +import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigResponse; + +/** + * Framework-agnostic utility for formatting A2A responses as Server-Sent Events (SSE). + *

    + * Provides static methods to serialize A2A responses to JSON and format them as SSE events. + * This allows HTTP server frameworks (Vert.x, Jakarta/WildFly, etc.) to use their own + * reactive libraries for publisher mapping while sharing the serialization logic. + *

+ * Example usage (Quarkus/Vert.x with Mutiny):
+ * <pre>{@code
+ * Flow.Publisher<A2AResponse> responses = handler.onMessageSendStream(request, context);
+ * AtomicLong eventId = new AtomicLong(0);
+ *
+ * Multi<String> sseEvents = Multi.createFrom().publisher(responses)
+ *     .map(response -> SseFormatter.formatResponseAsSSE(response, eventId.getAndIncrement()));
+ *
+ * sseEvents.subscribe().with(sseEvent -> httpResponse.write(Buffer.buffer(sseEvent)));
+ * }</pre>
+ * <p>
+ * Example usage (Jakarta/WildFly with a custom reactive library):
+ * <pre>{@code
+ * Flow.Publisher<String> jsonStrings = restHandler.getJsonPublisher();
+ * AtomicLong eventId = new AtomicLong(0);
+ *
+ * Flow.Publisher<String> sseEvents = mapPublisher(jsonStrings,
+ *     json -> SseFormatter.formatJsonAsSSE(json, eventId.getAndIncrement()));
+ * }</pre>
    + */ +public class SseFormatter { + + private SseFormatter() { + // Utility class - prevent instantiation + } + + /** + * Format an A2A response as an SSE event. + *

+ * Serializes the response to JSON and formats it as:
+ * <pre>
+ * data: {"jsonrpc":"2.0","result":{...},"id":123}
+ * id: 0
+ * </pre>
+ *
    + * + * @param response the A2A response to format + * @param eventId the SSE event ID + * @return SSE-formatted string (ready to write to HTTP response) + */ + public static String formatResponseAsSSE(A2AResponse response, long eventId) { + String jsonData = serializeResponse(response); + return "data: " + jsonData + "\nid: " + eventId + "\n\n"; + } + + /** + * Format a pre-serialized JSON string as an SSE event. + *

+ * Wraps the JSON in SSE format as:
+ * <pre>
+ * data: {"jsonrpc":"2.0","result":{...},"id":123}
+ * id: 0
+ * </pre>
+ *
+ *

    + * Use this when you already have JSON strings (e.g., from REST transport) + * and just need to add SSE formatting. + * + * @param jsonString the JSON string to wrap + * @param eventId the SSE event ID + * @return SSE-formatted string (ready to write to HTTP response) + */ + public static String formatJsonAsSSE(String jsonString, long eventId) { + return "data: " + jsonString + "\nid: " + eventId + "\n\n"; + } + + /** + * Serialize an A2AResponse to JSON string. + */ + private static String serializeResponse(A2AResponse response) { + // For error responses, use standard JSON-RPC error format + if (response instanceof A2AErrorResponse error) { + return JSONRPCUtils.toJsonRPCErrorResponse(error.getId(), error.getError()); + } + if (response.getError() != null) { + return JSONRPCUtils.toJsonRPCErrorResponse(response.getId(), response.getError()); + } + + // Convert domain response to protobuf message and serialize + com.google.protobuf.MessageOrBuilder protoMessage = convertToProto(response); + return JSONRPCUtils.toJsonRPCResultResponse(response.getId(), protoMessage); + } + + /** + * Convert A2AResponse to protobuf message for serialization. + */ + private static com.google.protobuf.MessageOrBuilder convertToProto(A2AResponse response) { + if (response instanceof GetTaskResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.task(r.getResult()); + } else if (response instanceof CancelTaskResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.task(r.getResult()); + } else if (response instanceof SendMessageResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.taskOrMessage(r.getResult()); + } else if (response instanceof ListTasksResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.listTasksResult(r.getResult()); + } else if (response instanceof SetTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.setTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof GetTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.getTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof ListTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.listTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof DeleteTaskPushNotificationConfigResponse) { + // DeleteTaskPushNotificationConfig has no result body, just return empty message + return com.google.protobuf.Empty.getDefaultInstance(); + } else if (response instanceof GetExtendedAgentCardResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.getExtendedCardResponse(r.getResult()); + } else if (response instanceof SendStreamingMessageResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.taskOrMessageStream(r.getResult()); + } else { + throw new IllegalArgumentException("Unknown response type: " + response.getClass().getName()); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/util/sse/package-info.java b/server-common/src/main/java/io/a2a/server/util/sse/package-info.java new file mode 100644 index 000000000..7e668b632 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/util/sse/package-info.java @@ -0,0 +1,11 @@ +/** + * Server-Sent Events (SSE) formatting utilities for A2A streaming responses. + *

    + * Provides framework-agnostic conversion of {@code Flow.Publisher>} to + * {@code Flow.Publisher} with SSE formatting, enabling easy integration with + * any HTTP server framework (Vert.x, Jakarta Servlet, etc.). + */ +@NullMarked +package io.a2a.server.util.sse; + +import org.jspecify.annotations.NullMarked; diff --git a/server-common/src/main/resources/META-INF/a2a-defaults.properties b/server-common/src/main/resources/META-INF/a2a-defaults.properties index 280fd943b..719be9e7a 100644 --- a/server-common/src/main/resources/META-INF/a2a-defaults.properties +++ b/server-common/src/main/resources/META-INF/a2a-defaults.properties @@ -19,3 +19,7 @@ a2a.executor.max-pool-size=50 # Keep-alive time for idle threads (seconds) a2a.executor.keep-alive-seconds=60 + +# Queue capacity for pending tasks (must be bounded to enable pool growth) +# When queue is full, new threads are created up to max-pool-size +a2a.executor.queue-capacity=100 diff --git a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java index 4354f1639..146bfb10a 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java @@ -16,6 +16,8 @@ import java.util.concurrent.atomic.AtomicReference; import io.a2a.jsonrpc.common.json.JsonProcessingException; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.A2AError; import io.a2a.spec.A2AServerException; import io.a2a.spec.Artifact; @@ -27,14 +29,19 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class EventConsumerTest { + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + private static final String TASK_ID = "123"; // Must match MINIMAL_TASK id + private EventQueue eventQueue; private EventConsumer eventConsumer; - + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; private static final String MINIMAL_TASK = """ { @@ -54,10 +61,59 @@ public class EventConsumerTest { @BeforeEach public void init() { - eventQueue = EventQueue.builder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); eventConsumer = new EventConsumer(eventQueue); } + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. 
+ * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + } + @Test public void testConsumeOneTaskEvent() throws Exception { Task event = fromJson(MINIMAL_TASK, Task.class); @@ -92,7 +148,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { List events = List.of( fromJson(MINIMAL_TASK, Task.class), TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -100,7 +156,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -128,7 +184,7 @@ public void testConsumeUntilMessage() throws Exception { List events = List.of( fromJson(MINIMAL_TASK, Task.class), TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -136,7 +192,7 @@ public void testConsumeUntilMessage() throws Exception { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -185,14 +241,14 @@ public void testConsumeMessageEvents() throws Exception { @Test public void testConsumeTaskInputRequired() { Task task = Task.builder() - .id("task-id") - .contextId("task-context") + .id(TASK_ID) + .contextId("session-xyz") .status(new TaskStatus(TaskState.INPUT_REQUIRED)) .build(); List events = List.of( task, TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -200,7 +256,7 @@ public void testConsumeTaskInputRequired() { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -332,7 +388,9 @@ public void onComplete() { @Test public void testConsumeAllStopsOnQueueClosed() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Close the queue immediately @@ -378,12 +436,16 @@ public void onComplete() { @Test public void testConsumeAllHandlesQueueClosedException() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Add a 
message event (which will complete the stream) Event message = fromJson(MESSAGE_PAYLOAD, Message.class); - queue.enqueueEvent(message); + + // Use callback to wait for event processing + waitForEventProcessing(() -> queue.enqueueEvent(message)); // Close the queue before consuming queue.close(); @@ -428,11 +490,13 @@ public void onComplete() { @Test public void testConsumeAllTerminatesOnQueueClosedEvent() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Enqueue a QueueClosedEvent (poison pill) - QueueClosedEvent queueClosedEvent = new QueueClosedEvent("task-123"); + QueueClosedEvent queueClosedEvent = new QueueClosedEvent(TASK_ID); queue.enqueueEvent(queueClosedEvent); Flow.Publisher publisher = consumer.consumeAll(); @@ -477,8 +541,12 @@ public void onComplete() { } private void enqueueAndConsumeOneEvent(Event event) throws Exception { - eventQueue.enqueueEvent(event); + // Use callback to wait for event processing + waitForEventProcessing(() -> eventQueue.enqueueEvent(event)); + + // Event is now available, consume it directly Event result = eventConsumer.consumeOne(); + assertNotNull(result, "Event should be available"); assertSame(event, result); } diff --git a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java index a3dc7d916..2499a8173 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java @@ -11,7 +11,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.A2AError; import io.a2a.spec.Artifact; import io.a2a.spec.Event; @@ -23,12 +27,17 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class EventQueueTest { private EventQueue eventQueue; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + + private static final String TASK_ID = "123"; // Must match MINIMAL_TASK id private static final String MINIMAL_TASK = """ { @@ -46,38 +55,96 @@ public class EventQueueTest { } """; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach public void init() { - eventQueue = EventQueue.builder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); + } + + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + 
EventQueueUtil.stop(mainEventBusProcessor); + } + } + /** + * Helper to create a queue with MainEventBus configured (for tests that need event distribution). + */ + private EventQueue createQueueWithEventBus(String taskId) { + return EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(taskId) + .build(); + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. + * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, io.a2a.spec.Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } } @Test public void testConstructorDefaultQueueSize() { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); assertEquals(EventQueue.DEFAULT_QUEUE_SIZE, queue.getQueueSize()); } @Test public void testConstructorCustomQueueSize() { int customSize = 500; - EventQueue queue = EventQueue.builder().queueSize(customSize).build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(customSize).build(); assertEquals(customSize, queue.getQueueSize()); } @Test public void testConstructorInvalidQueueSize() { // Test zero queue size - assertThrows(IllegalArgumentException.class, () -> EventQueue.builder().queueSize(0).build()); + assertThrows(IllegalArgumentException.class, () -> EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(0).build()); // Test negative queue size - assertThrows(IllegalArgumentException.class, () -> EventQueue.builder().queueSize(-10).build()); + assertThrows(IllegalArgumentException.class, () -> EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(-10).build()); } @Test public void testTapCreatesChildQueue() { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue childQueue = parentQueue.tap(); assertNotNull(childQueue); @@ -87,7 +154,7 @@ public void testTapCreatesChildQueue() { @Test public void testTapOnChildQueueThrowsException() { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue childQueue = parentQueue.tap(); assertThrows(IllegalStateException.class, () -> childQueue.tap()); @@ -95,69 +162,74 @@ public void testTapOnChildQueueThrowsException() { @Test public void testEnqueueEventPropagagesToChildren() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - 
parentQueue.enqueueEvent(event); + mainQueue.enqueueEvent(event); - // Event should be available in both parent and child queues - Event parentEvent = parentQueue.dequeueEventItem(-1).getEvent(); - Event childEvent = childQueue.dequeueEventItem(-1).getEvent(); + // Event should be available in all child queues + // Note: MainEventBusProcessor runs async, so we use dequeueEventItem with timeout + Event child1Event = childQueue1.dequeueEventItem(5000).getEvent(); + Event child2Event = childQueue2.dequeueEventItem(5000).getEvent(); - assertSame(event, parentEvent); - assertSame(event, childEvent); + assertSame(event, child1Event); + assertSame(event, child2Event); } @Test public void testMultipleChildQueuesReceiveEvents() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue1 = parentQueue.tap(); - EventQueue childQueue2 = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); + EventQueue childQueue3 = mainQueue.tap(); Event event1 = fromJson(MINIMAL_TASK, Task.class); Event event2 = fromJson(MESSAGE_PAYLOAD, Message.class); - parentQueue.enqueueEvent(event1); - parentQueue.enqueueEvent(event2); + mainQueue.enqueueEvent(event1); + mainQueue.enqueueEvent(event2); - // All queues should receive both events - assertSame(event1, parentQueue.dequeueEventItem(-1).getEvent()); - assertSame(event2, parentQueue.dequeueEventItem(-1).getEvent()); + // All child queues should receive both events + // Note: Use timeout for async processing + assertSame(event1, childQueue1.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue1.dequeueEventItem(5000).getEvent()); - assertSame(event1, childQueue1.dequeueEventItem(-1).getEvent()); - assertSame(event2, childQueue1.dequeueEventItem(-1).getEvent()); + assertSame(event1, childQueue2.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue2.dequeueEventItem(5000).getEvent()); - assertSame(event1, childQueue2.dequeueEventItem(-1).getEvent()); - assertSame(event2, childQueue2.dequeueEventItem(-1).getEvent()); + assertSame(event1, childQueue3.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue3.dequeueEventItem(5000).getEvent()); } @Test public void testChildQueueDequeueIndependently() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue1 = parentQueue.tap(); - EventQueue childQueue2 = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); + EventQueue childQueue3 = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - parentQueue.enqueueEvent(event); + mainQueue.enqueueEvent(event); - // Dequeue from child1 first - Event child1Event = childQueue1.dequeueEventItem(-1).getEvent(); + // Dequeue from child1 first (use timeout for async processing) + Event child1Event = childQueue1.dequeueEventItem(5000).getEvent(); assertSame(event, child1Event); // child2 should still have the event available - Event child2Event = childQueue2.dequeueEventItem(-1).getEvent(); + Event child2Event = childQueue2.dequeueEventItem(5000).getEvent(); assertSame(event, child2Event); - // Parent should still have the event available - Event parentEvent = parentQueue.dequeueEventItem(-1).getEvent(); - assertSame(event, parentEvent); + // child3 should still have the event available + Event child3Event = 
childQueue3.dequeueEventItem(5000).getEvent(); + assertSame(event, child3Event); } @Test public void testCloseImmediatePropagationToChildren() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = createQueueWithEventBus(TASK_ID); EventQueue childQueue = parentQueue.tap(); // Add events to both parent and child @@ -166,7 +238,7 @@ public void testCloseImmediatePropagationToChildren() throws Exception { assertFalse(childQueue.isClosed()); try { - assertNotNull(childQueue.dequeueEventItem(-1)); // Child has the event + assertNotNull(childQueue.dequeueEventItem(5000)); // Child has the event (use timeout) } catch (EventQueueClosedException e) { // This is fine if queue closed before dequeue } @@ -187,27 +259,37 @@ public void testCloseImmediatePropagationToChildren() throws Exception { @Test public void testEnqueueEventWhenClosed() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .build(); + EventQueue childQueue = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - queue.close(); // Close the queue first - assertTrue(queue.isClosed()); + childQueue.close(); // Close the child queue first (removes from children list) + assertTrue(childQueue.isClosed()); + + // Create a new child queue BEFORE enqueuing (ensures it's in children list for distribution) + EventQueue newChildQueue = mainQueue.tap(); // MainQueue accepts events even when closed (for replication support) // This ensures late-arriving replicated events can be enqueued to closed queues - queue.enqueueEvent(event); + // Note: MainEventBusProcessor runs asynchronously, so we use dequeueEventItem with timeout + mainQueue.enqueueEvent(event); - // Event should be available for dequeuing - Event dequeuedEvent = queue.dequeueEventItem(-1).getEvent(); + // New child queue should receive the event (old closed child was removed from children list) + EventQueueItem item = newChildQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); - // Now queue is closed and empty, should throw exception - assertThrows(EventQueueClosedException.class, () -> queue.dequeueEventItem(-1)); + // Now new child queue is closed and empty, should throw exception + newChildQueue.close(); + assertThrows(EventQueueClosedException.class, () -> newChildQueue.dequeueEventItem(-1)); } @Test public void testDequeueEventWhenClosedAndEmpty() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build().tap(); queue.close(); assertTrue(queue.isClosed()); @@ -217,19 +299,27 @@ public void testDequeueEventWhenClosedAndEmpty() throws Exception { @Test public void testDequeueEventWhenClosedButHasEvents() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .build(); + EventQueue childQueue = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - queue.enqueueEvent(event); - queue.close(); // Graceful close - events should remain - assertTrue(queue.isClosed()); + // Use callback to wait for event processing instead of polling + waitForEventProcessing(() -> mainQueue.enqueueEvent(event)); - // Should still be able to dequeue existing events - Event dequeuedEvent = queue.dequeueEventItem(-1).getEvent(); + // At this point, 
event has been processed and distributed to childQueue + childQueue.close(); // Graceful close - events should remain + assertTrue(childQueue.isClosed()); + + // Should still be able to dequeue existing events from closed queue + EventQueueItem item = childQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); // Now queue is closed and empty, should throw exception - assertThrows(EventQueueClosedException.class, () -> queue.dequeueEventItem(-1)); + assertThrows(EventQueueClosedException.class, () -> childQueue.dequeueEventItem(-1)); } @Test @@ -244,7 +334,9 @@ public void testEnqueueAndDequeueEvent() throws Exception { public void testDequeueEventNoWait() throws Exception { Event event = fromJson(MINIMAL_TASK, Task.class); eventQueue.enqueueEvent(event); - Event dequeuedEvent = eventQueue.dequeueEventItem(-1).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); } @@ -257,7 +349,7 @@ public void testDequeueEventEmptyQueueNoWait() throws Exception { @Test public void testDequeueEventWait() throws Exception { Event event = TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -271,7 +363,7 @@ public void testDequeueEventWait() throws Exception { @Test public void testTaskDone() throws Exception { Event event = TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -347,7 +439,7 @@ public void testCloseIdempotent() throws Exception { assertTrue(eventQueue.isClosed()); // Test with immediate close as well - EventQueue eventQueue2 = EventQueue.builder().build(); + EventQueue eventQueue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); eventQueue2.close(true); assertTrue(eventQueue2.isClosed()); @@ -361,19 +453,20 @@ public void testCloseIdempotent() throws Exception { */ @Test public void testCloseChildQueues() throws Exception { - EventQueue childQueue = eventQueue.tap(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue childQueue = mainQueue.tap(); assertTrue(childQueue != null); // Graceful close - parent closes but children remain open - eventQueue.close(); - assertTrue(eventQueue.isClosed()); + mainQueue.close(); + assertTrue(mainQueue.isClosed()); assertFalse(childQueue.isClosed()); // Child NOT closed on graceful parent close // Immediate close - parent force-closes all children - EventQueue parentQueue2 = EventQueue.builder().build(); - EventQueue childQueue2 = parentQueue2.tap(); - parentQueue2.close(true); // immediate=true - assertTrue(parentQueue2.isClosed()); + EventQueue mainQueue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue childQueue2 = mainQueue2.tap(); + mainQueue2.close(true); // immediate=true + assertTrue(mainQueue2.isClosed()); assertTrue(childQueue2.isClosed()); // Child IS closed on immediate parent close } @@ -383,7 +476,7 @@ public void testCloseChildQueues() throws Exception { */ @Test public void testMainQueueReferenceCountingStaysOpenWithActiveChildren() throws Exception { - EventQueue mainQueue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue child1 = mainQueue.tap(); EventQueue child2 = mainQueue.tap(); diff --git 
a/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java b/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java index 39201c1f6..6c9ed4a17 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java @@ -1,8 +1,39 @@ package io.a2a.server.events; +import java.util.concurrent.atomic.AtomicInteger; + public class EventQueueUtil { - // Since EventQueue.builder() is package protected, add a method to expose it - public static EventQueue.EventQueueBuilder getEventQueueBuilder() { - return EventQueue.builder(); + // Counter for generating unique test taskIds + private static final AtomicInteger TASK_ID_COUNTER = new AtomicInteger(0); + + /** + * Get an EventQueue builder pre-configured with the shared test MainEventBus and a unique taskId. + *

+ * <p>
+ * Note: this returns the MainQueue - tests should call {@code tap()} if they need to consume events.

    + * + * @return builder with TEST_EVENT_BUS and unique taskId already set + */ + public static EventQueue.EventQueueBuilder getEventQueueBuilder(MainEventBus eventBus) { + return EventQueue.builder(eventBus) + .taskId("test-task-" + TASK_ID_COUNTER.incrementAndGet()); + } + + /** + * Start a MainEventBusProcessor instance. + * + * @param processor the processor to start + */ + public static void start(MainEventBusProcessor processor) { + processor.start(); + } + + /** + * Stop a MainEventBusProcessor instance. + * + * @param processor the processor to stop + */ + public static void stop(MainEventBusProcessor processor) { + processor.stop(); } } diff --git a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java index 1eca1b739..3e09ff2af 100644 --- a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java @@ -14,7 +14,10 @@ import java.util.concurrent.ExecutionException; import java.util.stream.IntStream; +import io.a2a.server.tasks.InMemoryTaskStore; import io.a2a.server.tasks.MockTaskStateProvider; +import io.a2a.server.tasks.PushNotificationSender; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -22,17 +25,30 @@ public class InMemoryQueueManagerTest { private InMemoryQueueManager queueManager; private MockTaskStateProvider taskStateProvider; + private InMemoryTaskStore taskStore; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach public void setUp() { taskStateProvider = new MockTaskStateProvider(); - queueManager = new InMemoryQueueManager(taskStateProvider); + taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + queueManager = new InMemoryQueueManager(taskStateProvider, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + } + + @AfterEach + public void tearDown() { + EventQueueUtil.stop(mainEventBusProcessor); } @Test public void testAddNewQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); @@ -43,8 +59,8 @@ public void testAddNewQueue() { @Test public void testAddExistingQueueThrowsException() { String taskId = "test_task_id"; - EventQueue queue1 = EventQueue.builder().build(); - EventQueue queue2 = EventQueue.builder().build(); + EventQueue queue1 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue queue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue1); @@ -56,7 +72,7 @@ public void testAddExistingQueueThrowsException() { @Test public void testGetExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); EventQueue result = queueManager.get(taskId); @@ -73,7 +89,7 @@ public void testGetNonexistentQueue() { @Test public void testTapExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + 
EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); EventQueue tappedQueue = queueManager.tap(taskId); @@ -94,7 +110,7 @@ public void testTapNonexistentQueue() { @Test public void testCloseExistingQueue() { String taskId = "test_task_id"; - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); queueManager.close(taskId); @@ -129,7 +145,7 @@ public void testCreateOrTapNewQueue() { @Test public void testCreateOrTapExistingQueue() { String taskId = "test_task_id"; - EventQueue originalQueue = EventQueue.builder().build(); + EventQueue originalQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, originalQueue); EventQueue result = queueManager.createOrTap(taskId); @@ -151,7 +167,7 @@ public void testConcurrentOperations() throws InterruptedException, ExecutionExc // Add tasks concurrently List> addFutures = taskIds.stream() .map(taskId -> CompletableFuture.supplyAsync(() -> { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); queueManager.add(taskId, queue); return taskId; })) diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java index ea5bbe797..9c64f03f9 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java @@ -26,7 +26,10 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; import io.a2a.server.events.EventQueueItem; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.server.tasks.BasePushNotificationSender; import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; import io.a2a.server.tasks.InMemoryTaskStore; @@ -66,6 +69,8 @@ public class AbstractA2ARequestHandlerTest { private static final String PREFERRED_TRANSPORT = "preferred-transport"; private static final String A2A_REQUESTHANDLER_TEST_PROPERTIES = "/a2a-requesthandler-test.properties"; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + protected AgentExecutor executor; protected TaskStore taskStore; protected RequestHandler requestHandler; @@ -73,6 +78,8 @@ public class AbstractA2ARequestHandlerTest { protected AgentExecutorMethod agentExecutorCancel; protected InMemoryQueueManager queueManager; protected TestHttpClient httpClient; + protected MainEventBus mainEventBus; + protected MainEventBusProcessor mainEventBusProcessor; protected final Executor internalExecutor = Executors.newCachedThreadPool(); @@ -96,19 +103,31 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws A2AErro InMemoryTaskStore inMemoryTaskStore = new InMemoryTaskStore(); taskStore = inMemoryTaskStore; - queueManager = new InMemoryQueueManager(inMemoryTaskStore); + + // Create push notification components BEFORE MainEventBusProcessor httpClient = new TestHttpClient(); PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); PushNotificationSender pushSender = new 
BasePushNotificationSender(pushConfigStore, httpClient); + // Create MainEventBus and MainEventBusProcessor (production code path) + mainEventBus = new MainEventBus(); + queueManager = new InMemoryQueueManager(inMemoryTaskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, pushSender, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, pushConfigStore, pushSender, internalExecutor); + executor, taskStore, queueManager, pushConfigStore, mainEventBusProcessor, internalExecutor, internalExecutor); } @AfterEach public void cleanup() { agentExecutorExecute = null; agentExecutorCancel = null; + + // Stop MainEventBusProcessor background thread + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } } protected static AgentCard createAgentCard(boolean streaming, boolean pushNotifications, boolean stateTransitionHistory) { diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java index 293babe4e..e69de29bb 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java @@ -1,1001 +0,0 @@ -package io.a2a.server.requesthandlers; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; - -import io.a2a.server.ServerCallContext; -import io.a2a.server.agentexecution.AgentExecutor; -import io.a2a.server.agentexecution.RequestContext; -import io.a2a.server.auth.UnauthenticatedUser; -import io.a2a.server.events.EventQueue; -import io.a2a.server.events.InMemoryQueueManager; -import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; -import io.a2a.server.tasks.InMemoryTaskStore; -import io.a2a.server.tasks.TaskUpdater; -import io.a2a.spec.A2AError; -import io.a2a.spec.ListTaskPushNotificationConfigParams; -import io.a2a.spec.ListTaskPushNotificationConfigResult; -import io.a2a.spec.Message; -import io.a2a.spec.MessageSendConfiguration; -import io.a2a.spec.MessageSendParams; -import io.a2a.spec.PushNotificationConfig; -import io.a2a.spec.Task; -import io.a2a.spec.TaskState; -import io.a2a.spec.TaskStatus; -import io.a2a.spec.TextPart; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; - -/** - * Comprehensive tests for DefaultRequestHandler, backported from Python's - * test_default_request_handler.py. These tests cover core functionality that - * is transport-agnostic and should work across JSON-RPC, gRPC, and REST. - * - * Background cleanup and task tracking tests are from Python PR #440 and #472. 
- */ -public class DefaultRequestHandlerTest { - - private DefaultRequestHandler requestHandler; - private InMemoryTaskStore taskStore; - private InMemoryQueueManager queueManager; - private TestAgentExecutor agentExecutor; - private ServerCallContext serverCallContext; - - @BeforeEach - void setUp() { - taskStore = new InMemoryTaskStore(); - // Pass taskStore as TaskStateProvider to queueManager for task-aware queue management - queueManager = new InMemoryQueueManager(taskStore); - agentExecutor = new TestAgentExecutor(); - - requestHandler = DefaultRequestHandler.create( - agentExecutor, - taskStore, - queueManager, - null, // pushConfigStore - null, // pushSender - Executors.newCachedThreadPool() - ); - - serverCallContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of(), Set.of()); - } - - /** - * Test that multiple blocking messages to the same task work correctly - * when agent doesn't emit final events (fire-and-forget pattern). - * This replicates TCK test: test_message_send_continue_task - */ - @Test - @Timeout(10) - void testBlockingMessageContinueTask() throws Exception { - String taskId = "continue-task-1"; - String contextId = "continue-ctx-1"; - - // Configure agent to NOT complete tasks (like TCK fire-and-forget agent) - agentExecutor.setExecuteCallback((context, queue) -> { - Task task = context.getTask(); - if (task == null) { - // First message: create SUBMITTED task - task = Task.builder() - .id(context.getTaskId()) - .contextId(context.getContextId()) - .status(new TaskStatus(TaskState.SUBMITTED)) - .build(); - } else { - // Subsequent messages: emit WORKING task (non-final) - task = Task.builder() - .id(context.getTaskId()) - .contextId(context.getContextId()) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - } - queue.enqueueEvent(task); - // Don't complete - just return (fire-and-forget) - }); - - // First blocking message - should return SUBMITTED task - Message message1 = Message.builder() - .messageId("msg-1") - .role(Message.Role.USER) - .parts(new TextPart("first message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params1 = new MessageSendParams(message1, null, null, ""); - Object result1 = requestHandler.onMessageSend(params1, serverCallContext); - - assertTrue(result1 instanceof Task); - Task task1 = (Task) result1; - assertTrue(task1.id().equals(taskId)); - assertTrue(task1.status().state() == TaskState.SUBMITTED); - - // Second blocking message to SAME taskId - should not hang - Message message2 = Message.builder() - .messageId("msg-2") - .role(Message.Role.USER) - .parts(new TextPart("second message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params2 = new MessageSendParams(message2, null, null, ""); - Object result2 = requestHandler.onMessageSend(params2, serverCallContext); - - // Should complete successfully (not timeout) - assertTrue(result2 instanceof Task); - } - - /** - * Test that background cleanup tasks are properly tracked and cleared. 
- * Backported from Python test: test_background_cleanup_task_is_tracked_and_cleared - */ - @Test - @Timeout(10) - void testBackgroundCleanupTaskIsTrackedAndCleared() throws Exception { - String taskId = "track-task-1"; - String contextId = "track-ctx-1"; - - // Create a task that will trigger background cleanup - Task task = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.SUBMITTED)) - .build(); - - taskStore.save(task); - - Message message = Message.builder() - .messageId("msg-track") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Set up agent to finish quickly so cleanup runs - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowAgentFinish = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - try { - allowAgentFinish.await(5, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }); - - // Start streaming (this will create background tasks) - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start"); - - // Allow agent to finish, which should trigger cleanup - allowAgentFinish.countDown(); - - // Give some time for background tasks to be tracked and cleaned up - Thread.sleep(1000); - - // Background tasks should eventually be cleared - // Note: We can't directly access the backgroundTasks field without reflection, - // but the test verifies the mechanism doesn't hang or leak tasks - assertTrue(true, "Background cleanup completed without hanging"); - } - - /** - * Test that client disconnect triggers background cleanup and producer continues. 
- * Backported from Python test: test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues - */ - @Test - @Timeout(10) - void testStreamingClientDisconnectTriggersBackgroundCleanup() throws Exception { - String taskId = "disc-task-1"; - String contextId = "disc-ctx-1"; - - Message message = Message.builder() - .messageId("mid") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Agent should start and then wait - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowAgentFinish = new CountDownLatch(1); - AtomicBoolean agentCompleted = new AtomicBoolean(false); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - try { - allowAgentFinish.await(10, TimeUnit.SECONDS); - agentCompleted.set(true); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }); - - // Start streaming - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start executing"); - - // Simulate client disconnect by not consuming the stream - // In real scenarios, the client would close the connection - - // Agent should still be running (not finished immediately on "disconnect") - Thread.sleep(500); - assertTrue(agentExecutor.isExecuting(), "Producer should still be running after simulated disconnect"); - - // Allow agent to finish - allowAgentFinish.countDown(); - - // Wait a bit for completion - Thread.sleep(1000); - - assertTrue(agentCompleted.get(), "Agent should have completed execution"); - } - - /** - * Test that resubscription works after client disconnect. 
- * Backported from Python test: test_stream_disconnect_then_resubscribe_receives_future_events - */ - @Test - @Timeout(15) - void testStreamDisconnectThenResubscribeReceivesFutureEvents() throws Exception { - String taskId = "reconn-task-1"; - String contextId = "reconn-ctx-1"; - - // Create initial task - Task initialTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - - taskStore.save(initialTask); - - Message message = Message.builder() - .messageId("msg-reconn") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Set up agent to emit events with controlled timing - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowSecondEvent = new CountDownLatch(1); - CountDownLatch allowFinish = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - - // Emit first event - Task firstEvent = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - queue.enqueueEvent(firstEvent); - - // Wait for permission to emit second event - try { - allowSecondEvent.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; - } - - // Emit second event - Task secondEvent = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - queue.enqueueEvent(secondEvent); - - try { - allowFinish.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }); - - // Start streaming and simulate getting first event then disconnecting - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start and emit first event - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start"); - - // Simulate client disconnect (in real scenario, client would close connection) - // The background cleanup should keep the producer running - - // Now try to resubscribe to the task - io.a2a.spec.TaskIdParams resubParams = new io.a2a.spec.TaskIdParams(taskId); - - // Allow agent to emit second event - allowSecondEvent.countDown(); - - // Try resubscription - this should work because queue is still alive - var resubResult = requestHandler.onResubscribeToTask(resubParams, serverCallContext); - // If we get here without exception, resubscription worked - assertTrue(true, "Resubscription succeeded"); - - // Clean up - allowFinish.countDown(); - } - - /** - * Test that task state is persisted to task store after client disconnect. 
- * Backported from Python test: test_disconnect_persists_final_task_to_store - */ - @Test - @Timeout(15) - void testDisconnectPersistsFinalTaskToStore() throws Exception { - String taskId = "persist-task-1"; - String contextId = "persist-ctx-1"; - - Message message = Message.builder() - .messageId("msg-persist") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Agent that completes after some delay - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowCompletion = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - - // Emit working status - Task workingTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - queue.enqueueEvent(workingTask); - - try { - allowCompletion.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; - } - - // Emit final completed status - Task completedTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - queue.enqueueEvent(completedTask); - }); - - // Start streaming and simulate client disconnect - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start"); - - // Simulate client disconnect by not consuming the stream further - // In real scenarios, the reactive stream would be cancelled - - // Allow agent to complete in background - allowCompletion.countDown(); - - // Give time for background processing to persist the final state - Thread.sleep(2000); - - // Verify the final task state was persisted despite client disconnect - Task persistedTask = taskStore.get(taskId); - if (persistedTask != null) { - // If task was persisted, it should have the final state - assertTrue( - persistedTask.status().state() == TaskState.COMPLETED || - persistedTask.status().state() == TaskState.WORKING, - "Task should be persisted with working or completed state, got: " + persistedTask.status().state() - ); - } - // Note: In some architectures, the task might not be persisted if the - // background consumption isn't implemented. This test documents the expected behavior. - } - - /** - * Test that blocking message call waits for agent to finish and returns complete Task - * even when agent does fire-and-forget (emits non-final state and returns). - * - * Expected behavior: - * 1. Agent emits WORKING state with artifacts - * 2. Agent's execute() method returns WITHOUT emitting final state - * 3. Blocking onMessageSend() should wait for agent execution to complete - * 4. Blocking onMessageSend() should wait for all queued events to be processed - * 5. Returned Task should have WORKING state with all artifacts included - * - * This tests fire-and-forget pattern with blocking calls. 
- */ - @Test - @Timeout(15) - void testBlockingFireAndForgetReturnsNonFinalTask() throws Exception { - String taskId = "blocking-fire-forget-task"; - String contextId = "blocking-fire-forget-ctx"; - - Message message = Message.builder() - .messageId("msg-blocking-fire-forget") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendConfiguration config = MessageSendConfiguration.builder() - .blocking(true) - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent that does fire-and-forget: emits WORKING with artifact but never completes - agentExecutor.setExecuteCallback((context, queue) -> { - TaskUpdater updater = new TaskUpdater(context, queue); - - // Start work (WORKING state) - updater.startWork(); - - // Add artifact - updater.addArtifact( - List.of(new TextPart("Fire and forget artifact")), - "artifact-1", "FireForget", null); - - // Agent returns WITHOUT calling updater.complete() - // Task stays in WORKING state (non-final) - }); - - // Call blocking onMessageSend - should wait for agent to finish - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // The returned result should be a Task in WORKING state with artifact - assertTrue(result instanceof Task, "Result should be a Task"); - Task returnedTask = (Task) result; - - // Verify task is in WORKING state (non-final, fire-and-forget) - assertEquals(TaskState.WORKING, returnedTask.status().state(), - "Returned task should be WORKING (fire-and-forget), got: " + returnedTask.status().state()); - - // Verify artifacts are included in the returned task - assertNotNull(returnedTask.artifacts(), - "Returned task should have artifacts"); - assertTrue(returnedTask.artifacts().size() >= 1, - "Returned task should have at least 1 artifact, got: " + - returnedTask.artifacts().size()); - } - - /** - * Test that non-blocking message call returns immediately and persists all events in background. - * - * Expected behavior: - * 1. Non-blocking call returns immediately with first event (WORKING state) - * 2. Agent continues running in background and produces more events - * 3. Background consumption continues and persists all events to TaskStore - * 4. 
Final task state (COMPLETED) is persisted in background - */ - @Test - @Timeout(15) - void testNonBlockingMessagePersistsAllEventsInBackground() throws Exception { - String taskId = "blocking-persist-task"; - String contextId = "blocking-persist-ctx"; - - Message message = Message.builder() - .messageId("msg-nonblocking-persist") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - // Use default non-blocking behavior - MessageSendConfiguration config = MessageSendConfiguration.builder() - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent that produces multiple events with delays - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch firstEventEmitted = new CountDownLatch(1); - CountDownLatch allowCompletion = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - - // Emit first event (WORKING state) - Task workingTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - queue.enqueueEvent(workingTask); - firstEventEmitted.countDown(); - - // Sleep to ensure the non-blocking call has returned before we emit more events - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; - } - - // Wait for permission to complete - try { - allowCompletion.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; - } - - // Emit final event (COMPLETED state) - // This event should be persisted to TaskStore in background - Task completedTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - queue.enqueueEvent(completedTask); - }); - - // Call non-blocking onMessageSend - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // Assertion 1: The immediate result should be the first event (WORKING) - assertTrue(result instanceof Task, "Result should be a Task"); - Task immediateTask = (Task) result; - assertEquals(TaskState.WORKING, immediateTask.status().state(), - "Non-blocking should return immediately with WORKING state, got: " + immediateTask.status().state()); - - // At this point, the non-blocking call has returned, but the agent is still running - - // Allow the agent to emit the final COMPLETED event - allowCompletion.countDown(); - - // Assertion 2: Poll for the final task state to be persisted in background - // Use polling loop instead of fixed sleep for faster and more reliable test - long timeoutMs = 5000; - long startTime = System.currentTimeMillis(); - Task persistedTask = null; - boolean completedStateFound = false; - - while (System.currentTimeMillis() - startTime < timeoutMs) { - persistedTask = taskStore.get(taskId); - if (persistedTask != null && persistedTask.status().state() == TaskState.COMPLETED) { - completedStateFound = true; - break; - } - Thread.sleep(100); // Poll every 100ms - } - - assertTrue(persistedTask != null, "Task should be persisted to store"); - assertTrue( - completedStateFound, - "Final task state should be COMPLETED (background consumption should have processed it), got: " + - (persistedTask != null ? 
persistedTask.status().state() : "null") + - " after " + (System.currentTimeMillis() - startTime) + "ms" - ); - } - - /** - * Test the BIG idea: MainQueue stays open for non-final tasks even when all children close. - * This enables fire-and-forget tasks and late resubscriptions. - */ - @Test - @Timeout(15) - void testMainQueueStaysOpenForNonFinalTasks() throws Exception { - String taskId = "fire-and-forget-task"; - String contextId = "fire-ctx"; - - // Create initial task in WORKING state (non-final) - Task initialTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - taskStore.save(initialTask); - - Message message = Message.builder() - .messageId("msg-fire") - .role(Message.Role.USER) - .parts(new TextPart("fire and forget")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Agent that emits WORKING status but never completes (fire-and-forget pattern) - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowAgentFinish = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - - // Emit WORKING status (non-final) - Task workingTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - queue.enqueueEvent(workingTask); - - // Don't emit final state - just wait and finish - try { - allowAgentFinish.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - // Agent finishes WITHOUT emitting final task state - }); - - // Start streaming - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start and emit WORKING event - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start"); - - // Give time for WORKING event to be processed - Thread.sleep(500); - - // Simulate client disconnect - this closes the ChildQueue - // but MainQueue should stay open because task is non-final - - // Allow agent to finish - allowAgentFinish.countDown(); - - // Give time for agent to finish and cleanup to run - Thread.sleep(2000); - - // THE BIG IDEA TEST: Resubscription should work because MainQueue is still open - // Even though: - // 1. The original ChildQueue closed (client disconnected) - // 2. The agent finished executing - // 3. Task is still in non-final WORKING state - // Therefore: MainQueue should still be open for resubscriptions - - io.a2a.spec.TaskIdParams resubParams = new io.a2a.spec.TaskIdParams(taskId); - var resubResult = requestHandler.onResubscribeToTask(resubParams, serverCallContext); - - // If we get here without exception, the BIG idea works! - assertTrue(true, "Resubscription succeeded - MainQueue stayed open for non-final task"); - } - - /** - * Test that MainQueue DOES close when task is finalized. - * This ensures Level 2 protection doesn't prevent cleanup of completed tasks. 
- */ - @Test - @Timeout(15) - void testMainQueueClosesForFinalizedTasks() throws Exception { - String taskId = "completed-task"; - String contextId = "completed-ctx"; - - // Create initial task in COMPLETED state (already finalized) - Task completedTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - taskStore.save(completedTask); - - // Create a queue for this task - EventQueue mainQueue = queueManager.createOrTap(taskId); - assertTrue(mainQueue != null, "Queue should be created"); - - // Close the child queue (simulating client disconnect) - mainQueue.close(); - - // Give time for cleanup callback to run - Thread.sleep(1000); - - // Since the task is finalized (COMPLETED), the MainQueue should be removed from the map - // This tests that Level 2 protection (childClosing check) allows cleanup for finalized tasks - EventQueue queue = queueManager.get(taskId); - assertTrue(queue == null || queue.isClosed(), - "Queue for finalized task should be null or closed"); - } - - /** - * Test that blocking message call returns a Task with ALL artifacts included. - * This reproduces the reported bug: blocking call returns before artifacts are processed. - * - * Expected behavior: - * 1. Agent emits multiple artifacts via TaskUpdater - * 2. Blocking onMessageSend() should wait for ALL events to be processed - * 3. Returned Task should have all artifacts included in COMPLETED state - * - * Bug manifestation: - * - onMessageSend() returns after first event - * - Artifacts are still being processed in background - * - Returned Task is incomplete - */ - @Test - @Timeout(15) - void testBlockingCallReturnsCompleteTaskWithArtifacts() throws Exception { - String taskId = "blocking-artifacts-task"; - String contextId = "blocking-artifacts-ctx"; - - Message message = Message.builder() - .messageId("msg-blocking-artifacts") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendConfiguration config = MessageSendConfiguration.builder() - .blocking(true) - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent that uses TaskUpdater to emit multiple artifacts (like real agents do) - agentExecutor.setExecuteCallback((context, queue) -> { - TaskUpdater updater = new TaskUpdater(context, queue); - - // Start work (WORKING state) - updater.startWork(); - - // Add first artifact - updater.addArtifact( - List.of(new TextPart("First artifact")), - "artifact-1", "First", null); - - // Add second artifact - updater.addArtifact( - List.of(new TextPart("Second artifact")), - "artifact-2", "Second", null); - - // Complete the task - updater.complete(); - }); - - // Call blocking onMessageSend - should wait for ALL events - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // The returned result should be a Task with ALL artifacts - assertTrue(result instanceof Task, "Result should be a Task"); - Task returnedTask = (Task) result; - - // Verify task is completed - assertEquals(TaskState.COMPLETED, returnedTask.status().state(), - "Returned task should be COMPLETED"); - - // Verify artifacts are included in the returned task - assertNotNull(returnedTask.artifacts(), - "Returned task should have artifacts"); - assertTrue(returnedTask.artifacts().size() >= 2, - "Returned task should have at least 2 artifacts, got: " + - returnedTask.artifacts().size()); - } - - /** - * Test that pushNotificationConfig from 
SendMessageConfiguration is stored for NEW tasks - * in non-streaming (blocking) mode. This reproduces the bug from issue #84. - * - * Expected behavior: - * 1. Client sends message with pushNotificationConfig in SendMessageConfiguration - * 2. Agent creates a new task - * 3. pushNotificationConfig should be stored in PushNotificationConfigStore - * 4. Config should be retrievable via getInfo() - */ - @Test - @Timeout(10) - void testBlockingMessageStoresPushNotificationConfigForNewTask() throws Exception { - String taskId = "push-config-blocking-new-task"; - String contextId = "push-config-ctx"; - - // Create test config store - InMemoryPushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); - - // Re-create request handler with pushConfigStore - requestHandler = DefaultRequestHandler.create( - agentExecutor, - taskStore, - queueManager, - pushConfigStore, // Add push config store - null, // pushSender - Executors.newCachedThreadPool() - ); - - // Create push notification config - PushNotificationConfig pushConfig = PushNotificationConfig.builder() - .id("config-1") - .url("https://example.com/webhook") - .token("test-token-123") - .build(); - - // Create message with pushNotificationConfig - Message message = Message.builder() - .messageId("msg-push-config") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendConfiguration config = MessageSendConfiguration.builder() - .blocking(true) - .pushNotificationConfig(pushConfig) - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent creates a new task - agentExecutor.setExecuteCallback((context, queue) -> { - TaskUpdater updater = new TaskUpdater(context, queue); - updater.submit(); // Creates new task in SUBMITTED state - updater.complete(); - }); - - // Call blocking onMessageSend - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // Verify result is a task - assertTrue(result instanceof Task, "Result should be a Task"); - Task returnedTask = (Task) result; - assertEquals(taskId, returnedTask.id()); - - // THE KEY ASSERTION: Verify pushNotificationConfig was stored - ListTaskPushNotificationConfigResult storedConfigs = pushConfigStore.getInfo(new ListTaskPushNotificationConfigParams(taskId)); - assertNotNull(storedConfigs, "Push notification config should be stored for new task"); - assertEquals(1, storedConfigs.size(), - "Should have exactly 1 push config stored"); - PushNotificationConfig storedConfig = storedConfigs.configs().get(0).pushNotificationConfig(); - assertEquals("config-1", storedConfig.id()); - assertEquals("https://example.com/webhook", storedConfig.url()); - } - - /** - * Test that pushNotificationConfig is stored for EXISTING tasks. - * This verifies the initMessageSend logic works correctly. 
- */ - @Test - @Timeout(10) - void testBlockingMessageStoresPushNotificationConfigForExistingTask() throws Exception { - String taskId = "push-config-existing-task"; - String contextId = "push-config-existing-ctx"; - - // Create test config store - InMemoryPushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); - - // Re-create request handler with pushConfigStore - requestHandler = DefaultRequestHandler.create( - agentExecutor, - taskStore, - queueManager, - pushConfigStore, // Add push config store - null, // pushSender - Executors.newCachedThreadPool() - ); - - // Create EXISTING task in store - Task existingTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - taskStore.save(existingTask); - - // Create push notification config - PushNotificationConfig pushConfig = PushNotificationConfig.builder() - .id("config-existing-1") - .url("https://example.com/existing-webhook") - .token("existing-token-789") - .build(); - - Message message = Message.builder() - .messageId("msg-push-existing") - .role(Message.Role.USER) - .parts(new TextPart("update existing task")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendConfiguration config = MessageSendConfiguration.builder() - .blocking(true) - .pushNotificationConfig(pushConfig) - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent updates the existing task - agentExecutor.setExecuteCallback((context, queue) -> { - TaskUpdater updater = new TaskUpdater(context, queue); - updater.addArtifact( - List.of(new TextPart("update artifact")), - "artifact-1", "Update", null); - updater.complete(); - }); - - // Call blocking onMessageSend - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // Verify result - assertTrue(result instanceof Task, "Result should be a Task"); - - // Verify pushNotificationConfig was stored (initMessageSend path) - ListTaskPushNotificationConfigResult storedConfigs = pushConfigStore.getInfo(new ListTaskPushNotificationConfigParams(taskId)); - assertNotNull(storedConfigs,"Push notification config should be stored for existing task"); - assertEquals(1, storedConfigs.size(),"Should have exactly 1 push config stored"); - PushNotificationConfig storedConfig = storedConfigs.configs().get(0).pushNotificationConfig(); - assertEquals("config-existing-1", storedConfig.id()); - assertEquals("https://example.com/existing-webhook", storedConfig.url()); - } - - /** - * Simple test agent executor that allows controlling execution timing - */ - private static class TestAgentExecutor implements AgentExecutor { - private ExecuteCallback executeCallback; - private volatile boolean executing = false; - - interface ExecuteCallback { - void call(RequestContext context, EventQueue queue) throws A2AError; - } - - void setExecuteCallback(ExecuteCallback callback) { - this.executeCallback = callback; - } - - boolean isExecuting() { - return executing; - } - - @Override - public void execute(RequestContext context, EventQueue eventQueue) throws A2AError { - executing = true; - try { - if (executeCallback != null) { - // Custom callback is responsible for emitting events - executeCallback.call(context, eventQueue); - } else { - // No custom callback - emit default completion event - Task completedTask = Task.builder() - .id(context.getTaskId()) - .contextId(context.getContextId()) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - 
eventQueue.enqueueEvent(completedTask); - } - - } finally { - executing = false; - } - } - - @Override - public void cancel(RequestContext context, EventQueue eventQueue) throws A2AError { - // Simple cancel implementation - executing = false; - } - } -} \ No newline at end of file diff --git a/server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java b/server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java index e814c1c15..e69de29bb 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java @@ -1,49 +0,0 @@ -package io.a2a.server.tasks; - -import static io.a2a.jsonrpc.common.json.JsonUtil.fromJson; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; - -import io.a2a.spec.Task; -import org.junit.jupiter.api.Test; - -public class InMemoryTaskStoreTest { - private static final String TASK_JSON = """ - { - "id": "task-abc", - "contextId" : "session-xyz", - "status": {"state": "submitted"} - }"""; - - @Test - public void testSaveAndGet() throws Exception { - InMemoryTaskStore store = new InMemoryTaskStore(); - Task task = fromJson(TASK_JSON, Task.class); - store.save(task); - Task retrieved = store.get(task.id()); - assertSame(task, retrieved); - } - - @Test - public void testGetNonExistent() throws Exception { - InMemoryTaskStore store = new InMemoryTaskStore(); - Task retrieved = store.get("nonexistent"); - assertNull(retrieved); - } - - @Test - public void testDelete() throws Exception { - InMemoryTaskStore store = new InMemoryTaskStore(); - Task task = fromJson(TASK_JSON, Task.class); - store.save(task); - store.delete(task.id()); - Task retrieved = store.get(task.id()); - assertNull(retrieved); - } - - @Test - public void testDeleteNonExistent() throws Exception { - InMemoryTaskStore store = new InMemoryTaskStore(); - store.delete("non-existent"); - } -} diff --git a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java index d64729077..0e25e9aad 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java @@ -11,18 +11,25 @@ import static org.mockito.Mockito.when; import java.util.Collections; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.spec.Event; import io.a2a.spec.EventKind; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -49,7 +56,7 @@ public class ResultAggregatorTest { @BeforeEach void setUp() { MockitoAnnotations.openMocks(this); - aggregator = new ResultAggregator(mockTaskManager, null, testExecutor); + aggregator = new ResultAggregator(mockTaskManager, null, testExecutor, testExecutor); } // Helper methods for creating sample data @@ 
-69,13 +76,45 @@ private Task createSampleTask(String taskId, TaskState statusState, String conte .build(); } + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. + * + * @param processor the processor to set callback on + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(MainEventBusProcessor processor, Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + processor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + processor.setCallback(null); + } + } + // Basic functionality tests @Test void testConstructorWithMessage() { Message initialMessage = createSampleMessage("initial", "msg1", Message.Role.USER); - ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, initialMessage, testExecutor); + ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, initialMessage, testExecutor, testExecutor); // Test that the message is properly stored by checking getCurrentResult assertEquals(initialMessage, aggregatorWithMessage.getCurrentResult()); @@ -86,7 +125,7 @@ void testConstructorWithMessage() { @Test void testGetCurrentResultWithMessageSet() { Message sampleMessage = createSampleMessage("hola", "msg1", Message.Role.USER); - ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, sampleMessage, testExecutor); + ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, sampleMessage, testExecutor, testExecutor); EventKind result = aggregatorWithMessage.getCurrentResult(); @@ -121,7 +160,7 @@ void testConstructorStoresTaskManagerCorrectly() { @Test void testConstructorWithNullMessage() { - ResultAggregator aggregatorWithNullMessage = new ResultAggregator(mockTaskManager, null, testExecutor); + ResultAggregator aggregatorWithNullMessage = new ResultAggregator(mockTaskManager, null, testExecutor, testExecutor); Task expectedTask = createSampleTask("null_msg_task", TaskState.WORKING, "ctx1"); when(mockTaskManager.getTask()).thenReturn(expectedTask); @@ -181,7 +220,7 @@ void testMultipleGetCurrentResultCalls() { void testGetCurrentResultWithMessageTakesPrecedence() { // Test that when both message and task are available, message takes precedence Message message = createSampleMessage("priority message", "pri1", Message.Role.USER); - ResultAggregator messageAggregator = new ResultAggregator(mockTaskManager, message, testExecutor); + ResultAggregator messageAggregator = new ResultAggregator(mockTaskManager, message, testExecutor, testExecutor); // Even if we set up the task manager to return something, message should take precedence Task task = createSampleTask("should_not_be_returned", TaskState.WORKING, "ctx1"); @@ -197,17 +236,24 @@ void testGetCurrentResultWithMessageTakesPrecedence() { @Test void testConsumeAndBreakNonBlocking() throws Exception { // Test that with 
blocking=false, the method returns after the first event - Task firstEvent = createSampleTask("non_blocking_task", TaskState.WORKING, "ctx1"); + String taskId = "test-task"; + Task firstEvent = createSampleTask(taskId, TaskState.WORKING, "ctx1"); // After processing firstEvent, the current result will be that task when(mockTaskManager.getTask()).thenReturn(firstEvent); // Create an event queue using QueueManager (which has access to builder) + MainEventBus mainEventBus = new MainEventBus(); + InMemoryTaskStore taskStore = new InMemoryTaskStore(); InMemoryQueueManager queueManager = - new InMemoryQueueManager(new MockTaskStateProvider()); + new InMemoryQueueManager(new MockTaskStateProvider(), mainEventBus); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}, queueManager); + EventQueueUtil.start(processor); + + EventQueue queue = queueManager.getEventQueueBuilder(taskId).build().tap(); - EventQueue queue = queueManager.getEventQueueBuilder("test-task").build(); - queue.enqueueEvent(firstEvent); + // Use callback to wait for event processing (replaces polling) + waitForEventProcessing(processor, () -> queue.enqueueEvent(firstEvent)); // Create real EventConsumer with the queue EventConsumer eventConsumer = @@ -221,11 +267,16 @@ void testConsumeAndBreakNonBlocking() throws Exception { assertEquals(firstEvent, result.eventType()); assertTrue(result.interrupted()); - verify(mockTaskManager).process(firstEvent); - // getTask() is called at least once for the return value (line 255) - // May be called once more if debug logging executes in time (line 209) - // The async consumer may or may not execute before verification, so we accept 1-2 calls - verify(mockTaskManager, atLeast(1)).getTask(); - verify(mockTaskManager, atMost(2)).getTask(); + // NOTE: ResultAggregator no longer calls taskManager.process() + // That responsibility has moved to MainEventBusProcessor for centralized persistence + // + // NOTE: Since firstEvent is a Task, ResultAggregator captures it directly from the queue + // (capturedTask.get() at line 283 in ResultAggregator). Therefore, taskManager.getTask() + // is only called for debug logging in taskIdForLogging() (line 305), which may or may not + // execute depending on timing and log level. We expect 0-1 calls, not 1-2. 
+ verify(mockTaskManager, atMost(1)).getTask(); + + // Cleanup: stop the processor + EventQueueUtil.stop(processor); } } diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java index f14ebc0fe..91010e52e 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java @@ -48,7 +48,7 @@ public void init() throws Exception { @Test public void testGetTaskExisting() { Task expectedTask = minimalTask; - taskStore.save(expectedTask); + taskStore.save(expectedTask, false); Task retrieved = taskManager.getTask(); assertSame(expectedTask, retrieved); } @@ -61,16 +61,16 @@ public void testGetTaskNonExistent() { @Test public void testSaveTaskEventNewTask() throws A2AServerException { - Task saved = taskManager.saveTaskEvent(minimalTask); + taskManager.saveTaskEvent(minimalTask, false); + Task saved = taskManager.getTask(); Task retrieved = taskManager.getTask(); assertSame(minimalTask, retrieved); - assertSame(retrieved, saved); } @Test public void testSaveTaskEventStatusUpdate() throws A2AServerException { Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); TaskStatus newStatus = new TaskStatus( TaskState.WORKING, @@ -88,11 +88,11 @@ public void testSaveTaskEventStatusUpdate() throws A2AServerException { new HashMap<>()); - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updated = taskManager.getTask(); assertNotSame(initialTask, updated); - assertSame(updated, saved); assertEquals(initialTask.id(), updated.id()); assertEquals(initialTask.contextId(), updated.contextId()); @@ -114,10 +114,10 @@ public void testSaveTaskEventArtifactUpdate() throws A2AServerException { .contextId(minimalTask.contextId()) .artifact(newArtifact) .build(); - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updatedTask = taskManager.getTask(); - assertSame(updatedTask, saved); assertNotSame(initialTask, updatedTask); assertEquals(initialTask.id(), updatedTask.id()); @@ -144,7 +144,8 @@ public void testEnsureTaskNonExistentForStatusUpdate() throws A2AServerException .isFinal(false) .build(); - Task task = taskManagerWithoutId.saveTaskEvent(event); + taskManagerWithoutId.saveTaskEvent(event, false); + Task task = taskManagerWithoutId.getTask(); assertEquals(event.taskId(), taskManagerWithoutId.getTaskId()); assertEquals(event.contextId(), taskManagerWithoutId.getContextId()); @@ -164,13 +165,13 @@ public void testSaveTaskEventNewTaskNoTaskId() throws A2AServerException { .status(new TaskStatus(TaskState.WORKING)) .build(); - Task saved = taskManagerWithoutId.saveTaskEvent(task); + taskManagerWithoutId.saveTaskEvent(task, false); + Task saved = taskManagerWithoutId.getTask(); assertEquals(task.id(), taskManagerWithoutId.getTaskId()); assertEquals(task.contextId(), taskManagerWithoutId.getContextId()); Task retrieved = taskManagerWithoutId.getTask(); assertSame(task, retrieved); - assertSame(retrieved, saved); } @Test @@ -194,7 +195,7 @@ public void testTaskArtifactUpdateEventAppendTrueWithExistingArtifact() throws A Task taskWithArtifact = Task.builder(initialTask) .artifacts(Collections.singletonList(existingArtifact)) .build(); - taskStore.save(taskWithArtifact); + taskStore.save(taskWithArtifact, false); // Test: Append new parts 
to existing artifact Artifact newArtifact = Artifact.builder() @@ -209,7 +210,8 @@ public void testTaskArtifactUpdateEventAppendTrueWithExistingArtifact() throws A .append(true) .build(); - Task updatedTask = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task updatedTask = taskManager.getTask(); assertEquals(1, updatedTask.artifacts().size()); Artifact updatedArtifact = updatedTask.artifacts().get(0); @@ -223,7 +225,7 @@ public void testTaskArtifactUpdateEventAppendTrueWithExistingArtifact() throws A public void testTaskArtifactUpdateEventAppendTrueWithoutExistingArtifact() throws A2AServerException { // Setup: Create a task without artifacts Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); // Test: Try to append to non-existent artifact (should be ignored) Artifact newArtifact = Artifact.builder() @@ -238,7 +240,8 @@ public void testTaskArtifactUpdateEventAppendTrueWithoutExistingArtifact() throw .append(true) .build(); - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updatedTask = taskManager.getTask(); // Should have no artifacts since append was ignored @@ -257,7 +260,7 @@ public void testTaskArtifactUpdateEventAppendFalseWithExistingArtifact() throws Task taskWithArtifact = Task.builder(initialTask) .artifacts(Collections.singletonList(existingArtifact)) .build(); - taskStore.save(taskWithArtifact); + taskStore.save(taskWithArtifact, false); // Test: Replace existing artifact (append=false) Artifact newArtifact = Artifact.builder() @@ -272,7 +275,8 @@ public void testTaskArtifactUpdateEventAppendFalseWithExistingArtifact() throws .append(false) .build(); - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updatedTask = taskManager.getTask(); assertEquals(1, updatedTask.artifacts().size()); @@ -294,7 +298,7 @@ public void testTaskArtifactUpdateEventAppendNullWithExistingArtifact() throws A Task taskWithArtifact = Task.builder(initialTask) .artifacts(Collections.singletonList(existingArtifact)) .build(); - taskStore.save(taskWithArtifact); + taskStore.save(taskWithArtifact, false); // Test: Replace existing artifact (append=null, defaults to false) Artifact newArtifact = Artifact.builder() @@ -308,7 +312,8 @@ public void testTaskArtifactUpdateEventAppendNullWithExistingArtifact() throws A .artifact(newArtifact) .build(); // append is null - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updatedTask = taskManager.getTask(); assertEquals(1, updatedTask.artifacts().size()); @@ -330,7 +335,7 @@ public void testAddingTaskWithDifferentIdFails() { .build(); assertThrows(A2AServerException.class, () -> { - taskManagerWithId.saveTaskEvent(differentTask); + taskManagerWithId.saveTaskEvent(differentTask, false); }); } @@ -347,7 +352,7 @@ public void testAddingTaskWithDifferentIdViaStatusUpdateFails() { .build(); assertThrows(A2AServerException.class, () -> { - taskManagerWithId.saveTaskEvent(event); + taskManagerWithId.saveTaskEvent(event, false); }); } @@ -368,7 +373,7 @@ public void testAddingTaskWithDifferentIdViaArtifactUpdateFails() { .build(); assertThrows(A2AServerException.class, () -> { - taskManagerWithId.saveTaskEvent(event); + taskManagerWithId.saveTaskEvent(event, false); }); } @@ -392,7 +397,8 @@ public void testTaskWithNoMessageUsesInitialMessage() 
throws A2AServerException .isFinal(false) .build(); - Task saved = taskManagerWithInitialMessage.saveTaskEvent(event); + taskManagerWithInitialMessage.saveTaskEvent(event, false); + Task saved = taskManagerWithInitialMessage.getTask(); Task retrieved = taskManagerWithInitialMessage.getTask(); // Check that the task has the initial message in its history @@ -429,7 +435,8 @@ public void testTaskWithMessageDoesNotUseInitialMessage() throws A2AServerExcept .isFinal(false) .build(); - Task saved = taskManagerWithInitialMessage.saveTaskEvent(event); + taskManagerWithInitialMessage.saveTaskEvent(event, false); + Task saved = taskManagerWithInitialMessage.getTask(); Task retrieved = taskManagerWithInitialMessage.getTask(); // There should now be a history containing the initialMessage @@ -447,7 +454,7 @@ public void testTaskWithMessageDoesNotUseInitialMessage() throws A2AServerExcept public void testMultipleArtifactsWithSameArtifactId() throws A2AServerException { // Test handling of multiple artifacts with the same artifactId Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); // Add first artifact Artifact artifact1 = Artifact.builder() @@ -460,7 +467,7 @@ public void testMultipleArtifactsWithSameArtifactId() throws A2AServerException .contextId(minimalTask.contextId()) .artifact(artifact1) .build(); - taskManager.saveTaskEvent(event1); + taskManager.saveTaskEvent(event1, false); // Add second artifact with same artifactId (should replace the first) Artifact artifact2 = Artifact.builder() @@ -473,7 +480,7 @@ public void testMultipleArtifactsWithSameArtifactId() throws A2AServerException .contextId(minimalTask.contextId()) .artifact(artifact2) .build(); - taskManager.saveTaskEvent(event2); + taskManager.saveTaskEvent(event2, false); Task updatedTask = taskManager.getTask(); assertEquals(1, updatedTask.artifacts().size()); @@ -487,7 +494,7 @@ public void testMultipleArtifactsWithSameArtifactId() throws A2AServerException public void testMultipleArtifactsWithDifferentArtifactIds() throws A2AServerException { // Test handling of multiple artifacts with different artifactIds Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); // Add first artifact Artifact artifact1 = Artifact.builder() @@ -500,7 +507,7 @@ public void testMultipleArtifactsWithDifferentArtifactIds() throws A2AServerExce .contextId(minimalTask.contextId()) .artifact(artifact1) .build(); - taskManager.saveTaskEvent(event1); + taskManager.saveTaskEvent(event1, false); // Add second artifact with different artifactId (should be added) Artifact artifact2 = Artifact.builder() @@ -513,7 +520,7 @@ public void testMultipleArtifactsWithDifferentArtifactIds() throws A2AServerExce .contextId(minimalTask.contextId()) .artifact(artifact2) .build(); - taskManager.saveTaskEvent(event2); + taskManager.saveTaskEvent(event2, false); Task updatedTask = taskManager.getTask(); assertEquals(2, updatedTask.artifacts().size()); @@ -545,7 +552,7 @@ public void testInvalidTaskIdValidation() { public void testSaveTaskEventMetadataUpdate() throws A2AServerException { // Test that metadata from TaskStatusUpdateEvent gets saved to the task Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); Map newMetadata = new HashMap<>(); newMetadata.put("meta_key_test", "meta_value_test"); @@ -558,7 +565,7 @@ public void testSaveTaskEventMetadataUpdate() throws A2AServerException { .metadata(newMetadata) .build(); - taskManager.saveTaskEvent(event); + 
taskManager.saveTaskEvent(event, false); Task updatedTask = taskManager.getTask(); assertEquals(newMetadata, updatedTask.metadata()); @@ -568,7 +575,7 @@ public void testSaveTaskEventMetadataUpdate() throws A2AServerException { public void testSaveTaskEventMetadataUpdateNull() throws A2AServerException { // Test that null metadata in TaskStatusUpdateEvent doesn't affect task Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId(minimalTask.id()) @@ -578,7 +585,7 @@ public void testSaveTaskEventMetadataUpdateNull() throws A2AServerException { .metadata(null) .build(); - taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); Task updatedTask = taskManager.getTask(); // Should preserve original task's metadata (which is likely null for minimal task) @@ -594,7 +601,7 @@ public void testSaveTaskEventMetadataMergeExisting() throws A2AServerException { Task taskWithMetadata = Task.builder(minimalTask) .metadata(originalMetadata) .build(); - taskStore.save(taskWithMetadata); + taskStore.save(taskWithMetadata, false); Map newMetadata = new HashMap<>(); newMetadata.put("new_key", "new_value"); @@ -607,7 +614,7 @@ public void testSaveTaskEventMetadataMergeExisting() throws A2AServerException { .metadata(newMetadata) .build(); - taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); Task updatedTask = taskManager.getTask(); @@ -634,7 +641,8 @@ public void testCreateTaskWithInitialMessage() throws A2AServerException { .isFinal(false) .build(); - Task savedTask = taskManagerWithMessage.saveTaskEvent(event); + taskManagerWithMessage.saveTaskEvent(event, false); + Task savedTask = taskManagerWithMessage.getTask(); // Verify task was created properly assertNotNull(savedTask); @@ -662,7 +670,8 @@ public void testCreateTaskWithoutInitialMessage() throws A2AServerException { .isFinal(false) .build(); - Task savedTask = taskManagerWithoutMessage.saveTaskEvent(event); + taskManagerWithoutMessage.saveTaskEvent(event, false); + Task savedTask = taskManagerWithoutMessage.getTask(); // Verify task was created properly assertNotNull(savedTask); @@ -685,7 +694,8 @@ public void testSaveTaskInternal() throws A2AServerException { .status(new TaskStatus(TaskState.WORKING)) .build(); - Task savedTask = taskManagerWithoutId.saveTaskEvent(newTask); + taskManagerWithoutId.saveTaskEvent(newTask, false); + Task savedTask = taskManagerWithoutId.getTask(); // Verify internal state was updated assertEquals("test-task-id", taskManagerWithoutId.getTaskId()); @@ -716,7 +726,8 @@ public void testUpdateWithMessage() throws A2AServerException { .isFinal(false) .build(); - Task saved = taskManagerWithInitialMessage.saveTaskEvent(event); + taskManagerWithInitialMessage.saveTaskEvent(event, false); + Task saved = taskManagerWithInitialMessage.getTask(); Message updateMessage = Message.builder() .role(Message.Role.USER) diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java index 40f763569..73da17824 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java @@ -14,7 +14,11 @@ import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueItem; import io.a2a.server.events.EventQueueUtil; +import 
io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; import io.a2a.spec.Event; import io.a2a.spec.Message; import io.a2a.spec.Part; @@ -22,6 +26,7 @@ import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -38,14 +43,28 @@ public class TaskUpdaterTest { private static final List> SAMPLE_PARTS = List.of(new TextPart("Test message")); + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + EventQueue eventQueue; + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; private TaskUpdater taskUpdater; @BeforeEach public void init() { - eventQueue = EventQueueUtil.getEventQueueBuilder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TEST_TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); RequestContext context = new RequestContext.Builder() .setTaskId(TEST_TASK_ID) .setContextId(TEST_TASK_CONTEXT_ID) @@ -53,10 +72,19 @@ public void init() { taskUpdater = new TaskUpdater(context, eventQueue); } + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } + } + @Test public void testAddArtifactWithCustomIdAndName() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "custom-artifact-id", "Custom Artifact", null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -239,7 +267,9 @@ public void testNewAgentMessageWithMetadata() throws Exception { @Test public void testAddArtifactWithAppendTrue() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, true, null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -258,7 +288,9 @@ public void testAddArtifactWithAppendTrue() throws Exception { @Test public void testAddArtifactWithLastChunkTrue() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, null, true); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -273,7 +305,9 @@ public void testAddArtifactWithLastChunkTrue() throws Exception { @Test public void testAddArtifactWithAppendAndLastChunk() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, true, false); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + 
EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -287,7 +321,9 @@ public void testAddArtifactWithAppendAndLastChunk() throws Exception { @Test public void testAddArtifactGeneratesIdWhenNull() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, null, "Test Artifact", null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -383,7 +419,9 @@ public void testConcurrentCompletionAttempts() throws Exception { thread2.join(); // Exactly one event should have been queued - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskStatusUpdateEvent.class, event); @@ -396,7 +434,10 @@ public void testConcurrentCompletionAttempts() throws Exception { } private TaskStatusUpdateEvent checkTaskStatusUpdateEventOnQueue(boolean isFinal, TaskState state, Message statusMessage) throws Exception { - Event event = eventQueue.dequeueEventItem(0).getEvent(); + // Wait up to 5 seconds for event (async MainEventBusProcessor needs time to distribute) + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskStatusUpdateEvent.class, event); @@ -408,6 +449,7 @@ private TaskStatusUpdateEvent checkTaskStatusUpdateEventOnQueue(boolean isFinal, assertEquals(state, tsue.status().state()); assertEquals(statusMessage, tsue.status().message()); + // Check no additional events (still use 0 timeout for this check) assertNull(eventQueue.dequeueEventItem(0)); return tsue; diff --git a/tck/src/main/resources/application.properties b/tck/src/main/resources/application.properties index c68793be4..b23747b00 100644 --- a/tck/src/main/resources/application.properties +++ b/tck/src/main/resources/application.properties @@ -12,6 +12,7 @@ a2a.executor.keep-alive-seconds=60 quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG quarkus.log.category."io.a2a.server.events".level=DEBUG quarkus.log.category."io.a2a.server.tasks".level=DEBUG +io.a2a.server.diagnostics.ThreadStats.level=DEBUG # Log to file for analysis quarkus.log.file.enable=true diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java index 724b58613..62acc506e 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java @@ -656,7 +656,20 @@ public void testResubscribeExistingTaskSuccess() throws Exception { AtomicReference errorRef = new AtomicReference<>(); // Create consumer to handle resubscribed events + AtomicBoolean receivedInitialTask = new AtomicBoolean(false); BiConsumer consumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!receivedInitialTask.get()) { + if (event instanceof TaskEvent) { + receivedInitialTask.set(true); + // Don't count down latch for initial Task + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: 
" + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent taskUpdateEvent) { if (taskUpdateEvent.getUpdateEvent() instanceof TaskArtifactUpdateEvent artifactEvent) { artifactUpdateEvent.set(artifactEvent); @@ -755,12 +768,25 @@ public void testResubscribeExistingTaskSuccessWithClientConsumers() throws Excep AtomicReference errorRef = new AtomicReference<>(); // Create consumer to handle resubscribed events + AtomicBoolean receivedInitialTask = new AtomicBoolean(false); AgentCard agentCard = createTestAgentCard(); ClientConfig clientConfig = createClientConfig(true); ClientBuilder clientBuilder = Client .builder(agentCard) .addConsumer((evt, agentCard1) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!receivedInitialTask.get()) { + if (evt instanceof TaskEvent) { + receivedInitialTask.set(true); + // Don't count down latch for initial Task + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: " + evt.getClass().getSimpleName()); + } + } + + // Process subsequent events if (evt instanceof TaskUpdateEvent taskUpdateEvent) { if (taskUpdateEvent.getUpdateEvent() instanceof TaskArtifactUpdateEvent artifactEvent) { artifactUpdateEvent.set(artifactEvent); @@ -918,8 +944,20 @@ public void testMainQueueReferenceCountingWithMultipleConsumers() throws Excepti AtomicReference firstConsumerEvent = new AtomicReference<>(); AtomicBoolean firstUnexpectedEvent = new AtomicBoolean(false); AtomicReference firstErrorRef = new AtomicReference<>(); + AtomicBoolean firstReceivedInitialTask = new AtomicBoolean(false); BiConsumer firstConsumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!firstReceivedInitialTask.get()) { + if (event instanceof TaskEvent) { + firstReceivedInitialTask.set(true); + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent tue && tue.getUpdateEvent() instanceof TaskArtifactUpdateEvent artifact) { firstConsumerEvent.set(artifact); firstConsumerLatch.countDown(); @@ -975,8 +1013,20 @@ public void testMainQueueReferenceCountingWithMultipleConsumers() throws Excepti AtomicReference secondConsumerEvent = new AtomicReference<>(); AtomicBoolean secondUnexpectedEvent = new AtomicBoolean(false); AtomicReference secondErrorRef = new AtomicReference<>(); + AtomicBoolean secondReceivedInitialTask = new AtomicBoolean(false); BiConsumer secondConsumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!secondReceivedInitialTask.get()) { + if (event instanceof TaskEvent) { + secondReceivedInitialTask.set(true); + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent tue && tue.getUpdateEvent() instanceof TaskArtifactUpdateEvent artifact) { secondConsumerEvent.set(artifact); secondConsumerLatch.countDown(); @@ -1316,8 +1366,20 @@ public void testNonBlockingWithMultipleMessages() throws Exception { List resubReceivedEvents = new CopyOnWriteArrayList<>(); AtomicBoolean resubUnexpectedEvent = new AtomicBoolean(false); AtomicReference resubErrorRef = new AtomicReference<>(); + AtomicBoolean resubReceivedInitialTask = new AtomicBoolean(false); BiConsumer resubConsumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: 
ENFORCE that first event is TaskEvent + if (!resubReceivedInitialTask.get()) { + if (event instanceof TaskEvent) { + resubReceivedInitialTask.set(true); + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent tue) { resubReceivedEvents.add(tue.getUpdateEvent()); resubEventLatch.countDown(); @@ -1355,6 +1417,7 @@ public void testNonBlockingWithMultipleMessages() throws Exception { AtomicBoolean streamUnexpectedEvent = new AtomicBoolean(false); BiConsumer streamConsumer = (event, agentCard) -> { + // This consumer is for sendMessage() (not resubscribe), so it doesn't get initial TaskEvent if (event instanceof TaskUpdateEvent tue) { streamReceivedEvents.add(tue.getUpdateEvent()); streamEventLatch.countDown(); diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java index 9df23c565..45483f214 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java @@ -31,7 +31,7 @@ public class TestUtilsBean { PushNotificationConfigStore pushNotificationConfigStore; public void saveTask(Task task) { - taskStore.save(task); + taskStore.save(task, false); } public Task getTask(String taskId) { diff --git a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java index 408205aa2..439d97497 100644 --- a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java +++ b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java @@ -242,7 +242,7 @@ public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request, A2AExtensions.validateRequiredExtensions(getAgentCardInternal(), context); MessageSendParams params = FromProto.messageSendParams(request); Flow.Publisher publisher = getRequestHandler().onMessageSendStream(params, context); - convertToStreamResponse(publisher, responseObserver); + convertToStreamResponse(publisher, responseObserver, context); } catch (A2AError e) { handleError(responseObserver, e); } catch (SecurityException e) { @@ -264,7 +264,7 @@ public void subscribeToTask(io.a2a.grpc.SubscribeToTaskRequest request, ServerCallContext context = createCallContext(responseObserver); TaskIdParams params = FromProto.taskIdParams(request); Flow.Publisher publisher = getRequestHandler().onResubscribeToTask(params, context); - convertToStreamResponse(publisher, responseObserver); + convertToStreamResponse(publisher, responseObserver, context); } catch (A2AError e) { handleError(responseObserver, e); } catch (SecurityException e) { @@ -275,7 +275,8 @@ public void subscribeToTask(io.a2a.grpc.SubscribeToTaskRequest request, } private void convertToStreamResponse(Flow.Publisher publisher, - StreamObserver responseObserver) { + StreamObserver responseObserver, + ServerCallContext context) { CompletableFuture.runAsync(() -> { publisher.subscribe(new Flow.Subscriber() { private Flow.Subscription subscription; @@ -285,6 +286,18 @@ public void onSubscribe(Flow.Subscription subscription) { this.subscription = subscription; subscription.request(1); + // Detect gRPC client disconnect and call EventConsumer.cancel() directly + // This stops the polling loop without relying on subscription cancellation propagation + 
Context grpcContext = Context.current(); + grpcContext.addListener(new Context.CancellationListener() { + @Override + public void cancelled(Context ctx) { + LOGGER.fine(() -> "gRPC call cancelled by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + } + }, getExecutor()); + // Notify tests that we are subscribed Runnable runnable = streamingSubscribedRunnable; if (runnable != null) { @@ -305,6 +318,8 @@ public void onNext(StreamingEventKind event) { @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + subscription.cancel(); if (throwable instanceof A2AError jsonrpcError) { handleError(responseObserver, jsonrpcError); } else { @@ -329,6 +344,9 @@ public void getExtendedAgentCard(io.a2a.grpc.GetExtendedAgentCardRequest request if (extendedAgentCard != null) { responseObserver.onNext(ToProto.agentCard(extendedAgentCard)); responseObserver.onCompleted(); + } else { + // Extended agent card not configured - return error instead of hanging + handleError(responseObserver, new ExtendedAgentCardNotConfiguredError(null, "Extended agent card not configured", null)); } } catch (Throwable t) { handleInternalError(responseObserver, t); diff --git a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java index 690d69a87..afed1329f 100644 --- a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java +++ b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java @@ -84,7 +84,7 @@ public class GrpcHandlerTest extends AbstractA2ARequestHandlerTest { @Test public void testOnGetTaskSuccess() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); GetTaskRequest request = GetTaskRequest.newBuilder() .setName("tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()) .build(); @@ -120,7 +120,7 @@ public void testOnGetTaskNotFound() throws Exception { @Test public void testOnCancelTaskSuccess() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorCancel = (context, eventQueue) -> { // We need to cancel the task or the EventConsumer never finds a 'final' event. 
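Note on the client-disconnect handling added to convertToStreamResponse above: it hinges on gRPC's Context cancellation callbacks. A minimal sketch of that mechanism is below, assuming only the grpc-api dependency; the DisconnectAwareStream class and its onClientDisconnect method are illustrative names, not part of this change.

import io.grpc.Context;
import java.util.concurrent.Executor;

final class DisconnectAwareStream {

    // Runs stopWork when the current gRPC call's Context is cancelled (e.g. the client
    // disconnected), which is the same hook the handler uses to stop its polling loop.
    static void onClientDisconnect(Runnable stopWork, Executor executor) {
        Context grpcContext = Context.current();
        grpcContext.addListener(new Context.CancellationListener() {
            @Override
            public void cancelled(Context cancelledContext) {
                stopWork.run();
            }
        }, executor);
    }
}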
@@ -151,7 +151,7 @@ public void testOnCancelTaskSuccess() throws Exception { @Test public void testOnCancelTaskNotSupported() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorCancel = (context, eventQueue) -> { throw new UnsupportedOperationError(); @@ -199,7 +199,7 @@ public void testOnMessageNewMessageSuccess() throws Exception { @Test public void testOnMessageNewMessageWithExistingTaskSuccess() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getMessage()); }; @@ -281,8 +281,7 @@ public void testPushNotificationsNotSupportedError() throws Exception { @Test public void testOnGetPushNotificationNoPushNotifierConfig() throws Exception { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = - new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(false, true, false); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); @@ -293,8 +292,7 @@ public void testOnGetPushNotificationNoPushNotifierConfig() throws Exception { @Test public void testOnSetPushNotificationNoPushNotifierConfig() throws Exception { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(false, true, false); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); @@ -330,7 +328,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccess() throws Exception io.a2a.spec.Task task = io.a2a.spec.Task.builder(AbstractA2ARequestHandlerTest.MINIMAL_TASK) .history(new ArrayList<>()) .build(); - taskStore.save(task); + taskStore.save(task, false); List results = new ArrayList<>(); List errors = new ArrayList<>(); @@ -379,7 +377,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() throws Excep io.a2a.spec.Task task = io.a2a.spec.Task.builder(AbstractA2ARequestHandlerTest.MINIMAL_TASK) .history(new ArrayList<>()) .build(); - taskStore.save(task); + taskStore.save(task, false); // This is used to send events from a mock List events = List.of( @@ -424,9 +422,14 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() throws Excep @Test public void 
testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception { - GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - List events = List.of( - AbstractA2ARequestHandlerTest.MINIMAL_TASK, + // Use synchronous executor for push notifications to ensure deterministic ordering + // Without this, async push notifications can execute out of order, causing test flakiness + mainEventBusProcessor.setPushNotificationExecutor(Runnable::run); + + try { + GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); + List events = List.of( + AbstractA2ARequestHandlerTest.MINIMAL_TASK, TaskArtifactUpdateEvent.builder() .taskId(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()) .contextId(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId()) @@ -493,13 +496,16 @@ public void onCompleted() { Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); - curr = httpClient.tasks.get(2); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); - Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); - Assertions.assertEquals(1, curr.artifacts().size()); - Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); - Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + curr = httpClient.tasks.get(2); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); + Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); + Assertions.assertEquals(1, curr.artifacts().size()); + Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); + Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + } finally { + mainEventBusProcessor.setPushNotificationExecutor(null); + } } @Test @@ -517,7 +523,7 @@ public void testOnResubscribeNoExistingTaskError() throws Exception { @Test public void testOnResubscribeExistingTaskSuccess() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); queueManager.createOrTap(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()); agentExecutorExecute = (context, eventQueue) -> { @@ -542,9 +548,18 @@ public void testOnResubscribeExistingTaskSuccess() throws Exception { streamRecorder.awaitCompletion(5, TimeUnit.SECONDS); List result = streamRecorder.getValues(); Assertions.assertNotNull(result); - Assertions.assertEquals(1, result.size()); - StreamResponse response = result.get(0); - Assertions.assertTrue(response.hasMessage()); + // Per A2A Protocol Spec 3.1.6, resubscribe sends current Task as first event, + // followed by the Message from the agent executor + Assertions.assertEquals(2, result.size()); + + // ENFORCE that first event is Task + Assertions.assertTrue(result.get(0).hasTask(), + "First event on resubscribe MUST be Task (current state)"); + + // Second event should be Message from agent executor + StreamResponse response = result.get(1); + 
Assertions.assertTrue(response.hasMessage(), + "Expected Message after initial Task"); assertEquals(GRPC_MESSAGE, response.getMessage()); Assertions.assertNull(streamRecorder.getError()); } @@ -552,7 +567,7 @@ public void testOnResubscribeExistingTaskSuccess() throws Exception { @Test public void testOnResubscribeExistingTaskSuccessMocks() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); queueManager.createOrTap(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()); List events = List.of( @@ -627,7 +642,7 @@ public void testOnMessageStreamInternalError() throws Exception { @Test public void testListPushNotificationConfig() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); }; @@ -653,7 +668,7 @@ public void testListPushNotificationConfig() throws Exception { public void testListPushNotificationConfigNotSupported() throws Exception { AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(true, false, true); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); }; @@ -668,10 +683,9 @@ public void testListPushNotificationConfigNotSupported() throws Exception { @Test public void testListPushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); }; @@ -702,7 +716,7 @@ public void testListPushNotificationConfigTaskNotFound() { @Test public void testDeletePushNotificationConfig() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getTask() != null ? 
context.getTask() : context.getMessage()); }; @@ -725,7 +739,7 @@ public void testDeletePushNotificationConfig() throws Exception { public void testDeletePushNotificationConfigNotSupported() throws Exception { AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(true, false, true); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); }; @@ -741,8 +755,7 @@ public void testDeletePushNotificationConfigNotSupported() throws Exception { @Test public void testDeletePushNotificationConfigNoPushConfigStore() { - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); DeleteTaskPushNotificationConfigRequest request = DeleteTaskPushNotificationConfigRequest.newBuilder() @@ -1155,7 +1168,7 @@ private StreamRecorder sendMessageRequest(GrpcHandler handl } private StreamRecorder createTaskPushNotificationConfigRequest(GrpcHandler handler, String name) throws Exception { - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); PushNotificationConfig config = PushNotificationConfig.newBuilder() .setUrl("http://example.com") .setId("config456") diff --git a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java index b43c28029..4dc151626 100644 --- a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java +++ b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java @@ -3,6 +3,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import java.util.ArrayList; import java.util.Collections; @@ -30,6 +32,8 @@ import io.a2a.jsonrpc.common.wrappers.GetTaskResponse; import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigRequest; import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.ListTasksRequest; +import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; import io.a2a.jsonrpc.common.wrappers.ListTasksResult; import io.a2a.jsonrpc.common.wrappers.SendMessageRequest; import io.a2a.jsonrpc.common.wrappers.SendMessageResponse; @@ -37,12 +41,11 @@ import io.a2a.jsonrpc.common.wrappers.SendStreamingMessageResponse; import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigRequest; import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigResponse; -import 
io.a2a.jsonrpc.common.wrappers.ListTasksRequest; -import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; import io.a2a.jsonrpc.common.wrappers.SubscribeToTaskRequest; import io.a2a.server.ServerCallContext; import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.events.EventConsumer; +import io.a2a.server.events.MainEventBusProcessorCallback; import io.a2a.server.requesthandlers.AbstractA2ARequestHandlerTest; import io.a2a.server.requesthandlers.DefaultRequestHandler; import io.a2a.server.tasks.ResultAggregator; @@ -52,16 +55,15 @@ import io.a2a.spec.AgentExtension; import io.a2a.spec.AgentInterface; import io.a2a.spec.Artifact; -import io.a2a.spec.ExtendedAgentCardNotConfiguredError; -import io.a2a.spec.ExtensionSupportRequiredError; -import io.a2a.spec.VersionNotSupportedError; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; import io.a2a.spec.Event; +import io.a2a.spec.ExtendedAgentCardNotConfiguredError; +import io.a2a.spec.ExtensionSupportRequiredError; import io.a2a.spec.GetTaskPushNotificationConfigParams; import io.a2a.spec.InternalError; import io.a2a.spec.InvalidRequestError; -import io.a2a.spec.ListTasksParams; import io.a2a.spec.ListTaskPushNotificationConfigParams; +import io.a2a.spec.ListTasksParams; import io.a2a.spec.Message; import io.a2a.spec.MessageSendParams; import io.a2a.spec.PushNotificationConfig; @@ -78,6 +80,7 @@ import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; import io.a2a.spec.UnsupportedOperationError; +import io.a2a.spec.VersionNotSupportedError; import mutiny.zero.ZeroPublisher; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Disabled; @@ -92,7 +95,7 @@ public class JSONRPCHandlerTest extends AbstractA2ARequestHandlerTest { @Test public void testOnGetTaskSuccess() throws Exception { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.id())); GetTaskResponse response = handler.onGetTask(request, callContext); assertEquals(request.getId(), response.getId()); @@ -113,7 +116,7 @@ public void testOnGetTaskNotFound() throws Exception { @Test public void testOnCancelTaskSuccess() throws Exception { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); agentExecutorCancel = (context, eventQueue) -> { // We need to cancel the task or the EventConsumer never finds a 'final' event. @@ -138,7 +141,7 @@ public void testOnCancelTaskSuccess() throws Exception { @Test public void testOnCancelTaskNotSupported() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); agentExecutorCancel = (context, eventQueue) -> { throw new UnsupportedOperationError(); @@ -174,42 +177,13 @@ public void testOnMessageNewMessageSuccess() { SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); - // The Python implementation returns a Task here, but then again they are using hardcoded mocks and - // bypassing the whole EventQueue. 
- // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to - // the Task not having a 'final' state - // - // See testOnMessageNewMessageSuccessMocks() for a test more similar to the Python implementation Assertions.assertSame(message, response.getResult()); } - @Test - public void testOnMessageNewMessageSuccessMocks() { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - - Message message = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .build(); - - SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response; - try (MockedConstruction mocked = Mockito.mockConstruction( - EventConsumer.class, - (mock, context) -> { - Mockito.doReturn(ZeroPublisher.fromItems(wrapEvent(MINIMAL_TASK))).when(mock).consumeAll(); - Mockito.doCallRealMethod().when(mock).createAgentRunnableDoneCallback(); - })) { - response = handler.onMessageSend(request, callContext); - } - assertNull(response.getError()); - Assertions.assertSame(MINIMAL_TASK, response.getResult()); - } - @Test public void testOnMessageNewMessageWithExistingTaskSuccess() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getMessage()); }; @@ -220,38 +194,9 @@ public void testOnMessageNewMessageWithExistingTaskSuccess() { SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); - // The Python implementation returns a Task here, but then again they are using hardcoded mocks and - // bypassing the whole EventQueue. 
- // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to - // the Task not having a 'final' state - // - // See testOnMessageNewMessageWithExistingTaskSuccessMocks() for a test more similar to the Python implementation Assertions.assertSame(message, response.getResult()); } - @Test - public void testOnMessageNewMessageWithExistingTaskSuccessMocks() { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); - - Message message = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .build(); - SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response; - try (MockedConstruction mocked = Mockito.mockConstruction( - EventConsumer.class, - (mock, context) -> { - Mockito.doReturn(ZeroPublisher.fromItems(wrapEvent(MINIMAL_TASK))).when(mock).consumeAll(); - })) { - response = handler.onMessageSend(request, callContext); - } - assertNull(response.getError()); - Assertions.assertSame(MINIMAL_TASK, response.getResult()); - - } - @Test public void testOnMessageError() { // See testMessageOnErrorMocks() for a test more similar to the Python implementation, using mocks for @@ -352,9 +297,11 @@ public void onComplete() { @Test public void testOnMessageStreamNewMessageMultipleEventsSuccess() throws InterruptedException { + // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback + // We'll verify persistence by checking TaskStore after streaming completes JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - // Create multiple events to be sent during streaming + // Create multiple events to be sent during streaming Task taskEvent = Task.builder(MINIMAL_TASK) .status(new TaskStatus(TaskState.WORKING)) .build(); @@ -429,8 +376,8 @@ public void onComplete() { } }); - // Wait for all events to be received - Assertions.assertTrue(latch.await(2, TimeUnit.SECONDS), + // Wait for all events to be received (increased timeout for async processing) + assertTrue(latch.await(10, TimeUnit.SECONDS), "Expected to receive 3 events within timeout"); // Assert no error occurred during streaming @@ -456,6 +403,17 @@ public void onComplete() { "Third event should be a TaskStatusUpdateEvent"); assertEquals(MINIMAL_TASK.id(), receivedStatus.taskId()); assertEquals(TaskState.COMPLETED, receivedStatus.status().state()); + + // Verify events were persisted to TaskStore (poll for final state) + for (int i = 0; i < 50; i++) { + Task storedTask = taskStore.get(MINIMAL_TASK.id()); + if (storedTask != null && storedTask.status() != null + && TaskState.COMPLETED.equals(storedTask.status().state())) { + return; // Success - task finalized in TaskStore + } + Thread.sleep(100); + } + fail("Task should have been finalized in TaskStore within timeout"); } @Test @@ -538,7 +496,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccess() throws Exception Task task = Task.builder(MINIMAL_TASK) .history(new ArrayList<>()) .build(); - taskStore.save(task); + taskStore.save(task, false); Message message = Message.builder(MESSAGE) .taskId(task.id()) @@ -583,7 +541,7 @@ public void onComplete() { }); }); - Assertions.assertTrue(latch.await(1, TimeUnit.SECONDS)); + assertTrue(latch.await(1, TimeUnit.SECONDS)); subscriptionRef.get().cancel(); // The Python implementation has several events emitted since it uses mocks. 
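The TaskStore verification added to testOnMessageStreamNewMessageMultipleEventsSuccess above waits for asynchronous persistence by re-checking the store in a loop. A minimal sketch of that poll-until pattern, with a hypothetical PollUntil helper that is not part of the SDK:

import java.time.Duration;
import java.util.function.BooleanSupplier;

final class PollUntil {

    // Re-evaluates the condition until it holds or the timeout elapses; mirrors the
    // "poll TaskStore until the task reaches COMPLETED" loop used in the test.
    static boolean await(BooleanSupplier condition, Duration timeout, Duration interval)
            throws InterruptedException {
        long deadline = System.nanoTime() + timeout.toNanos();
        while (!condition.getAsBoolean()) {
            if (System.nanoTime() >= deadline) {
                return false;
            }
            Thread.sleep(interval.toMillis());
        }
        return true;
    }
}

In the test this would be used as something like PollUntil.await(() -> { Task t = taskStore.get(MINIMAL_TASK.id()); return t != null && TaskState.COMPLETED.equals(t.status().state()); }, Duration.ofSeconds(5), Duration.ofMillis(100)).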
// @@ -608,7 +566,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() { Task task = Task.builder(MINIMAL_TASK) .history(new ArrayList<>()) .build(); - taskStore.save(task); + taskStore.save(task, false); // This is used to send events from a mock List events = List.of( @@ -682,7 +640,7 @@ public void onComplete() { @Test public void testSetPushNotificationConfigSuccess() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); TaskPushNotificationConfig taskPushConfig = new TaskPushNotificationConfig( @@ -704,7 +662,7 @@ public void testSetPushNotificationConfigSuccess() { @Test public void testGetPushNotificationConfigSuccess() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage()); }; @@ -729,112 +687,124 @@ public void testGetPushNotificationConfigSuccess() { @Test public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception { - JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); - - List events = List.of( - MINIMAL_TASK, - TaskArtifactUpdateEvent.builder() - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .artifact(Artifact.builder() - .artifactId("11") - .parts(new TextPart("text")) - .build()) - .build(), - TaskStatusUpdateEvent.builder() - .taskId(MINIMAL_TASK.id()) - .contextId(MINIMAL_TASK.contextId()) - .status(new TaskStatus(TaskState.COMPLETED)) - .build()); - - agentExecutorExecute = (context, eventQueue) -> { - // Hardcode the events to send here - for (Event event : events) { - eventQueue.enqueueEvent(event); - } - }; - - TaskPushNotificationConfig config = new TaskPushNotificationConfig( - MINIMAL_TASK.id(), - PushNotificationConfig.builder().id("c295ea44-7543-4f78-b524-7a38915ad6e4").url("http://example.com").build(), "tenant"); - - SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); - SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext); - assertNull(stpnResponse.getError()); - - Message msg = Message.builder(MESSAGE) - .taskId(MINIMAL_TASK.id()) - .build(); - SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request, callContext); - - final List results = Collections.synchronizedList(new ArrayList<>()); - final AtomicReference subscriptionRef = new AtomicReference<>(); - final CountDownLatch latch = new CountDownLatch(6); - httpClient.latch = latch; - - Executors.newSingleThreadExecutor().execute(() -> { - response.subscribe(new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscriptionRef.set(subscription); - subscription.request(1); - } - - @Override - public void onNext(SendStreamingMessageResponse item) { - System.out.println("-> " + item.getResult()); - results.add(item.getResult()); - System.out.println(results); - subscriptionRef.get().request(1); - latch.countDown(); - } - - @Override - public void onError(Throwable throwable) { - subscriptionRef.get().cancel(); - } - - @Override - public void onComplete() { - 
subscriptionRef.get().cancel(); + // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback + + // Use synchronous executor for push notifications to ensure deterministic ordering + // Without this, async push notifications can execute out of order, causing test flakiness + mainEventBusProcessor.setPushNotificationExecutor(Runnable::run); + + try { + JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor); + taskStore.save(MINIMAL_TASK, false); + + List events = List.of( + MINIMAL_TASK, + TaskArtifactUpdateEvent.builder() + .taskId(MINIMAL_TASK.id()) + .contextId(MINIMAL_TASK.contextId()) + .artifact(Artifact.builder() + .artifactId("11") + .parts(new TextPart("text")) + .build()) + .build(), + TaskStatusUpdateEvent.builder() + .taskId(MINIMAL_TASK.id()) + .contextId(MINIMAL_TASK.contextId()) + .status(new TaskStatus(TaskState.COMPLETED)) + .build()); + + + agentExecutorExecute = (context, eventQueue) -> { + // Hardcode the events to send here + for (Event event : events) { + eventQueue.enqueueEvent(event); } + }; + + TaskPushNotificationConfig config = new TaskPushNotificationConfig( + MINIMAL_TASK.id(), + PushNotificationConfig.builder().id("c295ea44-7543-4f78-b524-7a38915ad6e4").url("http://example.com").build(), "tenant"); + + SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); + SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext); + assertNull(stpnResponse.getError()); + + Message msg = Message.builder(MESSAGE) + .taskId(MINIMAL_TASK.id()) + .build(); + SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); + + final List results = Collections.synchronizedList(new ArrayList<>()); + final AtomicReference subscriptionRef = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(6); + httpClient.latch = latch; + + Executors.newSingleThreadExecutor().execute(() -> { + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscriptionRef.set(subscription); + subscription.request(1); + } + + @Override + public void onNext(SendStreamingMessageResponse item) { + System.out.println("-> " + item.getResult()); + results.add(item.getResult()); + System.out.println(results); + subscriptionRef.get().request(1); + latch.countDown(); + } + + @Override + public void onError(Throwable throwable) { + subscriptionRef.get().cancel(); + } + + @Override + public void onComplete() { + subscriptionRef.get().cancel(); + } + }); }); - }); - Assertions.assertTrue(latch.await(5, TimeUnit.SECONDS)); - subscriptionRef.get().cancel(); - assertEquals(3, results.size()); - assertEquals(3, httpClient.tasks.size()); - - Task curr = httpClient.tasks.get(0); - assertEquals(MINIMAL_TASK.id(), curr.id()); - assertEquals(MINIMAL_TASK.contextId(), curr.contextId()); - assertEquals(MINIMAL_TASK.status().state(), curr.status().state()); - assertEquals(0, curr.artifacts() == null ? 
0 : curr.artifacts().size());
-
-        curr = httpClient.tasks.get(1);
-        assertEquals(MINIMAL_TASK.id(), curr.id());
-        assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
-        assertEquals(MINIMAL_TASK.status().state(), curr.status().state());
-        assertEquals(1, curr.artifacts().size());
-        assertEquals(1, curr.artifacts().get(0).parts().size());
-        assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text());
-
-        curr = httpClient.tasks.get(2);
-        assertEquals(MINIMAL_TASK.id(), curr.id());
-        assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
-        assertEquals(TaskState.COMPLETED, curr.status().state());
-        assertEquals(1, curr.artifacts().size());
-        assertEquals(1, curr.artifacts().get(0).parts().size());
-        assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text());
+            assertTrue(latch.await(5, TimeUnit.SECONDS));
+
+            subscriptionRef.get().cancel();
+            assertEquals(3, results.size());
+            assertEquals(3, httpClient.tasks.size());
+
+            Task curr = httpClient.tasks.get(0);
+            assertEquals(MINIMAL_TASK.id(), curr.id());
+            assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
+            assertEquals(MINIMAL_TASK.status().state(), curr.status().state());
+            assertEquals(0, curr.artifacts() == null ? 0 : curr.artifacts().size());
+
+            curr = httpClient.tasks.get(1);
+            assertEquals(MINIMAL_TASK.id(), curr.id());
+            assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
+            assertEquals(MINIMAL_TASK.status().state(), curr.status().state());
+            assertEquals(1, curr.artifacts().size());
+            assertEquals(1, curr.artifacts().get(0).parts().size());
+            assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text());
+
+            curr = httpClient.tasks.get(2);
+            assertEquals(MINIMAL_TASK.id(), curr.id());
+            assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
+            assertEquals(TaskState.COMPLETED, curr.status().state());
+            assertEquals(1, curr.artifacts().size());
+            assertEquals(1, curr.artifacts().get(0).parts().size());
+            assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text());
+        } finally {
+            mainEventBusProcessor.setPushNotificationExecutor(null);
+        }
     }
     @Test
     public void testOnResubscribeExistingTaskSuccess() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         queueManager.createOrTap(MINIMAL_TASK.id());
         agentExecutorExecute = (context, eventQueue) -> {
@@ -861,6 +831,7 @@ public void testOnResubscribeExistingTaskSuccess() {
         CompletableFuture future = new CompletableFuture<>();
         List results = new ArrayList<>();
+        AtomicBoolean receivedInitialTask = new AtomicBoolean(false);
         response.subscribe(new Flow.Subscriber<>() {
             private Flow.Subscription subscription;
@@ -873,7 +844,20 @@ public void onSubscribe(Flow.Subscription subscription) {
             @Override
             public void onNext(SendStreamingMessageResponse item) {
-                results.add(item.getResult());
+                StreamingEventKind event = item.getResult();
+                results.add(event);
+
+                // Per A2A Protocol Spec 3.1.6: ENFORCE that first event is Task
+                if (!receivedInitialTask.get()) {
+                    assertTrue(event instanceof Task,
+                        "First event on resubscribe MUST be Task (current state), but was: " + event.getClass().getSimpleName());
+                    receivedInitialTask.set(true);
+                } else {
+                    // Subsequent events should be the expected type (Message in this case)
+                    assertTrue(event instanceof Message,
+                        "Expected Message after initial Task, but was: " + event.getClass().getSimpleName());
+                }
+
                 subscription.request(1);
             }
@@ -892,16 +876,15 @@ public void onComplete() {
         future.join();
-        // The Python implementation has several events emitted since it uses mocks.
-        //
-        // See testOnMessageStreamNewMessageExistingTaskSuccessMocks() for a test more similar to the Python implementation
-        assertEquals(1, results.size());
+        // Verify we received exactly 2 events and the initial Task was received
+        assertEquals(2, results.size());
+        assertTrue(receivedInitialTask.get(), "Should have received initial Task event");
     }
     @Test
     public void testOnResubscribeExistingTaskSuccessMocks() throws Exception {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         queueManager.createOrTap(MINIMAL_TASK.id());
         List events = List.of(
@@ -1060,7 +1043,7 @@ public void onComplete() {
         if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) {
             assertEquals("Streaming is not supported by the agent", ire.getMessage());
         } else {
-            Assertions.fail("Expected a response containing an error");
+            fail("Expected a response containing an error");
         }
     }
@@ -1107,7 +1090,7 @@ public void onComplete() {
         if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) {
             assertEquals("Streaming is not supported by the agent", ire.getMessage());
         } else {
-            Assertions.fail("Expected a response containing an error");
+            fail("Expected a response containing an error");
         }
     }
@@ -1115,7 +1098,7 @@ public void onComplete() {
     public void testPushNotificationsNotSupportedError() {
         AgentCard card = createAgentCard(true, false, true);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         TaskPushNotificationConfig config = new TaskPushNotificationConfig(
@@ -1135,12 +1118,11 @@ public void testPushNotificationsNotSupportedError() {
     @Test
     public void testOnGetPushNotificationNoPushNotifierConfig() {
         // Create request handler without a push notifier
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         AgentCard card = createAgentCard(false, true, false);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         GetTaskPushNotificationConfigRequest request = new GetTaskPushNotificationConfigRequest("id", new GetTaskPushNotificationConfigParams(MINIMAL_TASK.id()));
@@ -1154,12 +1136,11 @@ public void testOnGetPushNotificationNoPushNotifierConfig() {
     @Test
     public void testOnSetPushNotificationNoPushNotifierConfig() {
         // Create request handler without a push notifier
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         AgentCard card = createAgentCard(false, true, false);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         TaskPushNotificationConfig config = new TaskPushNotificationConfig(
@@ -1246,12 +1227,11 @@ public void testDefaultRequestHandlerWithCustomComponents() {
     @Test
     public void testOnMessageSendErrorHandling() {
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         AgentCard card = createAgentCard(false, true, false);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         Message message = Message.builder(MESSAGE)
                 .taskId(MINIMAL_TASK.id())
@@ -1280,7 +1260,7 @@ public void testOnMessageSendErrorHandling() {
     @Test
     public void testOnMessageSendTaskIdMismatch() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = ((context, eventQueue) -> {
             eventQueue.enqueueEvent(MINIMAL_TASK);
@@ -1293,16 +1273,17 @@
     }
     @Test
-    public void testOnMessageStreamTaskIdMismatch() {
+    public void testOnMessageStreamTaskIdMismatch() throws InterruptedException {
+        // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
-        agentExecutorExecute = ((context, eventQueue) -> {
-            eventQueue.enqueueEvent(MINIMAL_TASK);
-        });
+        agentExecutorExecute = ((context, eventQueue) -> {
+            eventQueue.enqueueEvent(MINIMAL_TASK);
+        });
-        SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null));
-        Flow.Publisher response = handler.onMessageSendStream(request, callContext);
+        SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null));
+        Flow.Publisher response = handler.onMessageSendStream(request, callContext);
         CompletableFuture future = new CompletableFuture<>();
         List results = new ArrayList<>();
@@ -1347,7 +1328,7 @@ public void onComplete() {
     @Test
     public void testListPushNotificationConfig() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1378,7 +1359,7 @@ public void testListPushNotificationConfig() {
     public void testListPushNotificationConfigNotSupported() {
         AgentCard card = createAgentCard(true, false, true);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1404,10 +1385,9 @@ public void testListPushNotificationConfigNotSupported() {
     @Test
     public void testListPushNotificationConfigNoPushConfigStore() {
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1442,7 +1422,7 @@ public void testListPushNotificationConfigTaskNotFound() {
     @Test
     public void testDeletePushNotificationConfig() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1470,7 +1450,7 @@ public void testDeletePushNotificationConfig() {
     public void testDeletePushNotificationConfigNotSupported() {
         AgentCard card = createAgentCard(true, false, true);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1496,10 +1476,10 @@ public void testDeletePushNotificationConfigNotSupported() {
     @Test
     public void testDeletePushNotificationConfigNoPushConfigStore() {
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler =
+                DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1593,7 +1573,7 @@ public void onComplete() {
         });
         // The main thread should not be blocked - we should be able to continue immediately
-        Assertions.assertTrue(streamStarted.await(100, TimeUnit.MILLISECONDS),
+        assertTrue(streamStarted.await(100, TimeUnit.MILLISECONDS),
                 "Streaming subscription should start quickly without blocking main thread");
         // This proves the main thread is not blocked - we can do other work
@@ -1603,11 +1583,11 @@ public void onComplete() {
         mainThreadBlocked.set(false); // If we get here, main thread was not blocked
         // Wait for the actual event processing to complete
-        Assertions.assertTrue(eventProcessed.await(2, TimeUnit.SECONDS),
+        assertTrue(eventProcessed.await(2, TimeUnit.SECONDS),
                 "Event should be processed within reasonable time");
         // Verify we received the event and main thread was not blocked
-        Assertions.assertTrue(eventReceived.get(), "Should have received streaming event");
+        assertTrue(eventReceived.get(), "Should have received streaming event");
         Assertions.assertFalse(mainThreadBlocked.get(), "Main thread should not have been blocked");
     }
@@ -1646,7 +1626,7 @@ public void testExtensionSupportRequiredErrorOnMessageSend() {
         SendMessageResponse response = handler.onMessageSend(request, callContext);
         assertInstanceOf(ExtensionSupportRequiredError.class, response.getError());
-        Assertions.assertTrue(response.getError().getMessage().contains("https://example.com/test-extension"));
+        assertTrue(response.getError().getMessage().contains("https://example.com/test-extension"));
         assertNull(response.getResult());
     }
@@ -1715,7 +1695,7 @@ public void onComplete() {
         assertEquals(1, results.size());
         assertInstanceOf(ExtensionSupportRequiredError.class, results.get(0).getError());
-        Assertions.assertTrue(results.get(0).getError().getMessage().contains("https://example.com/streaming-extension"));
+        assertTrue(results.get(0).getError().getMessage().contains("https://example.com/streaming-extension"));
         assertNull(results.get(0).getResult());
     }
@@ -1805,7 +1785,7 @@ public void testVersionNotSupportedErrorOnMessageSend() {
         SendMessageResponse response = handler.onMessageSend(request, contextWithVersion);
         assertInstanceOf(VersionNotSupportedError.class, response.getError());
-        Assertions.assertTrue(response.getError().getMessage().contains("2.0"));
+        assertTrue(response.getError().getMessage().contains("2.0"));
         assertNull(response.getResult());
     }
@@ -1876,12 +1856,12 @@ public void onComplete() {
         });
         // Wait for async processing
-        Assertions.assertTrue(latch.await(2, TimeUnit.SECONDS), "Expected to receive error event within timeout");
+        assertTrue(latch.await(2, TimeUnit.SECONDS), "Expected to receive error event within timeout");
         assertEquals(1, results.size());
         SendStreamingMessageResponse result = results.get(0);
         assertInstanceOf(VersionNotSupportedError.class, result.getError());
-        Assertions.assertTrue(result.getError().getMessage().contains("2.0"));
+        assertTrue(result.getError().getMessage().contains("2.0"));
         assertNull(result.getResult());
     }
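The resubscribe test above pins down an ordering contract: the first event delivered after a resubscribe must be the current Task snapshot, and only then the subsequent updates. The sketch below is illustrative only and not part of this change; it assumes the Task and StreamingEventKind types referenced in this patch, while the class name and the Consumer-based shape are hypothetical.

// Hypothetical helper, not part of this patch: rejects streams that violate the
// "Task snapshot first" resubscribe contract before handing events to a delegate.
import java.util.function.Consumer;

import io.a2a.spec.StreamingEventKind; // assumed package, matching the imports used in this patch
import io.a2a.spec.Task;

final class ResubscribeOrderGuard implements Consumer<StreamingEventKind> {

    private final Consumer<StreamingEventKind> delegate;
    private boolean sawInitialTask;

    ResubscribeOrderGuard(Consumer<StreamingEventKind> delegate) {
        this.delegate = delegate;
    }

    @Override
    public void accept(StreamingEventKind event) {
        if (!sawInitialTask) {
            // First event after tasks/resubscribe must be the current Task state
            if (!(event instanceof Task)) {
                throw new IllegalStateException("First event on resubscribe must be the current Task, but was: "
                        + event.getClass().getSimpleName());
            }
            sawInitialTask = true;
        }
        delegate.accept(event);
    }
}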
diff --git a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java
index 3ffb56c5f..3273d6119 100644
--- a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java
+++ b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java
@@ -399,32 +399,46 @@ private Flow.Publisher convertToSendStreamingMessageResponse(
             Flow.Publisher publisher) {
         // We can't use the normal convertingProcessor since that propagates any errors as an error handled
         // via Subscriber.onError() rather than as part of the SendStreamingResponse payload
+        log.log(Level.FINE, "REST: convertToSendStreamingMessageResponse called, creating ZeroPublisher");
         return ZeroPublisher.create(createTubeConfig(), tube -> {
+            log.log(Level.FINE, "REST: ZeroPublisher tube created, starting CompletableFuture.runAsync");
             CompletableFuture.runAsync(() -> {
+                log.log(Level.FINE, "REST: Inside CompletableFuture, subscribing to EventKind publisher");
                 publisher.subscribe(new Flow.Subscriber() {
                     Flow.@Nullable Subscription subscription;
                     @Override
                     public void onSubscribe(Flow.Subscription subscription) {
+                        log.log(Level.FINE, "REST: onSubscribe called, storing subscription and requesting first event");
                         this.subscription = subscription;
                         subscription.request(1);
                     }
                     @Override
                     public void onNext(StreamingEventKind item) {
+                        log.log(Level.FINE, "REST: onNext called with event: {0}", item.getClass().getSimpleName());
                         try {
                             String payload = JsonFormat.printer().omittingInsignificantWhitespace().print(ProtoUtils.ToProto.taskOrMessageStream(item));
+                            log.log(Level.FINE, "REST: Converted to JSON, sending via tube: {0}", payload.substring(0, Math.min(100, payload.length())));
                             tube.send(payload);
+                            log.log(Level.FINE, "REST: tube.send() completed, requesting next event from EventConsumer");
+                            // Request next event from EventConsumer (Chain 1: EventConsumer → RestHandler)
+                            // This is safe because ZeroPublisher buffers items
+                            // Chain 2 (ZeroPublisher → MultiSseSupport) controls actual delivery via request(1) in onWriteDone()
                             if (subscription != null) {
                                 subscription.request(1);
+                            } else {
+                                log.log(Level.WARNING, "REST: subscription is null in onNext!");
                             }
                         } catch (InvalidProtocolBufferException ex) {
+                            log.log(Level.SEVERE, "REST: JSON conversion failed", ex);
                             onError(ex);
                         }
                     }
                     @Override
                     public void onError(Throwable throwable) {
+                        log.log(Level.SEVERE, "REST: onError called", throwable);
                         if (throwable instanceof A2AError jsonrpcError) {
                             tube.send(new HTTPRestErrorResponse(jsonrpcError).toJson());
                         } else {
@@ -435,6 +449,7 @@ public void onError(Throwable throwable) {
                     @Override
                     public void onComplete() {
+                        log.log(Level.FINE, "REST: onComplete called, calling tube.complete()");
                         tube.complete();
                     }
                 });
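The RestHandler change above layers FINE logging over a one-event-at-a-time pull: each tube.send() is followed by a single request(1) to the upstream EventConsumer, while the downstream SSE writer paces actual delivery with its own request(1) calls. The sketch below shows that request pattern in isolation against plain java.util.concurrent.Flow types; it is not part of this change, and the OneAtATimeSubscriber name and the Consumer standing in for the ZeroPublisher tube are illustrative assumptions.

// Hypothetical sketch, not part of this patch: the request(1)-per-event pull used by the
// subscriber above, with a Consumer standing in for the buffering tube.
import java.util.concurrent.Flow;
import java.util.function.Consumer;

final class OneAtATimeSubscriber<T> implements Flow.Subscriber<T> {

    private final Consumer<T> sink;
    private Flow.Subscription subscription;

    OneAtATimeSubscriber(Consumer<T> sink) {
        this.sink = sink;
    }

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        this.subscription = subscription;
        subscription.request(1); // pull only the first event
    }

    @Override
    public void onNext(T item) {
        sink.accept(item);       // hand off; the sink may buffer
        subscription.request(1); // then ask the upstream for exactly one more
    }

    @Override
    public void onError(Throwable throwable) {
        // a real implementation would surface this to the client; the sketch just records it
        throwable.printStackTrace();
    }

    @Override
    public void onComplete() {
        // nothing left to do in the sketch; the real handler completes the tube here
    }
}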
diff --git a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java
index 7d930415b..db3ff97aa 100644
--- a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java
+++ b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java
@@ -30,7 +30,7 @@ public class RestHandlerTest extends AbstractA2ARequestHandlerTest {
     @Test
     public void testGetTaskSuccess() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         RestHandler.HTTPRestResponse response = handler.getTask(MINIMAL_TASK.id(), 0, "", callContext);
@@ -59,7 +59,7 @@ public void testGetTaskNotFound() {
     @Test
     public void testListTasksStatusWireString() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         RestHandler.HTTPRestResponse response = handler.listTasks(null, "submitted", null, null, null, null, null, "", callContext);
@@ -162,7 +162,7 @@ public void testSendMessageEmptyBody() {
     @Test
     public void testCancelTaskSuccess() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorCancel = (context, eventQueue) -> {
             // We need to cancel the task or the EventConsumer never finds a 'final' event.
@@ -246,7 +246,7 @@ public void testSendStreamingMessageNotSupported() {
     @Test
     public void testPushNotificationConfigSuccess() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         String requestBody = """
                 {
@@ -293,7 +293,7 @@ public void testPushNotificationConfigNotSupported() {
     @Test
     public void testGetPushNotificationConfig() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         // First, create a push notification config
         String createRequestBody = """
@@ -322,7 +322,7 @@ public void testGetPushNotificationConfig() {
     @Test
     public void testDeletePushNotificationConfig() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         RestHandler.HTTPRestResponse response = handler.deleteTaskPushNotificationConfiguration(MINIMAL_TASK.id(), "default-config-id", "", callContext);
         Assertions.assertEquals(204, response.getStatusCode());
     }
@@ -330,7 +330,7 @@ public void testDeletePushNotificationConfig() {
     @Test
     public void testListPushNotificationConfigs() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         RestHandler.HTTPRestResponse response = handler.listTaskPushNotificationConfigurations(MINIMAL_TASK.id(), 0, "", "", callContext);
@@ -894,7 +894,7 @@ public void testListTasksNegativeTimestampReturns422() {
     @Test
     public void testListTasksUnixMillisecondsTimestamp() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         // Unix milliseconds timestamp should be accepted
         String timestampMillis = String.valueOf(System.currentTimeMillis() - 10000);
@@ -909,7 +909,7 @@ public void testListTasksUnixMillisecondsTimestamp() {
     @Test
     public void testListTasksProtobufEnumStatus() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         // Protobuf enum format (TASK_STATE_SUBMITTED) should be accepted
         RestHandler.HTTPRestResponse response = handler.listTasks(null, "TASK_STATE_SUBMITTED", null, null,
@@ -923,7 +923,7 @@ public void testListTasksEnumConstantStatus() {
         RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         // Enum constant format (SUBMITTED) should be accepted
         RestHandler.HTTPRestResponse response = handler.listTasks(null, "SUBMITTED", null, null,