From 46f7691077dd2ef7097a83fe048bbac76f2bef7b Mon Sep 17 00:00:00 2001
From: Kabir Khan
Date: Mon, 2 Feb 2026 19:08:07 +0000
Subject: [PATCH 1/8] feat: Implement MainEventBus architecture and resolve
 multi-instance replication race conditions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Introduce a centralized event bus with a single background processor to
guarantee that events persist to the TaskStore before they are distributed
to clients. This eliminates race conditions when multiple concurrent
requests update the same task.

**Key Components:**
- MainEventBus: Shared BlockingQueue for all MainQueue events
- MainEventBusProcessor: Single background thread for ordered processing
- Processing sequence: TaskStore.save() → PushNotifications → distributeToChildren()
- TaskStateProvider interface: Task state queries for queue lifecycle management

**Event Flow:**
```
AgentExecutor → MainQueue → MainEventBus → MainEventBusProcessor
  → TaskStore (persist first) → Push Notifications → ChildQueues → Clients
```

**Benefits:**
- Events persist before clients receive them (no stale data)
- Serial processing prevents concurrent TaskStore updates
- Platform-agnostic ChildQueue synchronization (works across gRPC/JSONRPC/REST)
- Clean separation: MainQueue (no local queue) vs ChildQueue (local queue for clients)

Implement two-level protection to keep MainQueues open for fire-and-forget
tasks and late resubscriptions while still cleaning up finalized tasks:

**Level 1 - Cleanup Callback:** Check TaskStateProvider.isTaskFinalized()
before removing the queue from the QueueManager map.

**Level 2 - Auto-Close Prevention:** MainQueue.childClosing() checks finality
before closing when the last ChildQueue disconnects.

**Result:** Non-final tasks keep their queues open for resubscription;
finalized tasks clean up immediately.

Previously, ReplicatedQueueManager.onTaskFinalized() sent full Task objects
to remote instances via Kafka, while local instances sent
TaskStatusUpdateEvent.
Client auto-close logic only checked for TaskStatusUpdateEvent.isFinal(),
causing connection leaks on remote instances.

**ReplicatedQueueManager.onTaskFinalized():** Convert the Task to a
TaskStatusUpdateEvent before sending to Kafka, ensuring consistent event
types across all instances.

**EventConsumer:** Add a 50ms delay before tube.complete() to allow the SSE
buffer to flush in replicated scenarios where events arrive via Kafka with
timing variations.

**SSEEventListener (JSONRPC):** Check both TaskStatusUpdateEvent.isFinal()
and Task.status().state().isFinal() for auto-close.

**RestSSEEventListener (REST):** Add complete auto-close logic (was missing
entirely).

**Benefits:**
- Handles late subscriptions to completed tasks gracefully
- Prevents connection leaks in all scenarios
- Consistent behavior across JSONRPC and REST transports
- Defensive programming for edge cases

**MultiInstanceReplicationTest:**
- Add TaskEvent handling and container log dumping on failure
- Verify both APP1 and APP2 receive all events, including final states
- Test late-arriving replicated events and poison-pill ordering

**Integration Tests:**
- EventConsumerTest: Grace-period mechanism for replicated scenarios
- ReplicatedQueueManagerTest: Event type conversion validation
- All existing tests updated for the MainEventBus architecture

**EventQueue.Builder:** Now requires a MainEventBus parameter (validates
non-null).

**QueueManager Implementations:** Must handle TaskStateProvider for
lifecycle checks.

Existing code continues to work: InMemoryQueueManager and
ReplicatedQueueManager automatically inject MainEventBus via CDI. Custom
QueueManager implementations should inject MainEventBus and pass it to
EventQueue.Builder.
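The persist-before-distribute guarantee described above can be sketched as a
minimal, self-contained model. All names here (MiniEventBus, Event, publish)
are illustrative stand-ins, not the real MainEventBus API: a single daemon
thread drains a shared BlockingQueue, appends to a TaskStore stand-in first,
and only then "delivers" to children, so the ordering holds by construction.

```java
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;

// Minimal sketch of the MainEventBus ordering guarantee; all names are
// illustrative, not the real A2A SDK API.
public class MiniEventBus {
    public record Event(String taskId, String payload) {}

    private final BlockingQueue<Event> bus = new LinkedBlockingQueue<>();
    final List<Event> store = new CopyOnWriteArrayList<>();    // stand-in for TaskStore
    final List<Event> children = new CopyOnWriteArrayList<>(); // stand-in for ChildQueues

    public MiniEventBus() {
        // Single background processor: serial handling means no two events
        // for the same task are ever persisted concurrently.
        Thread processor = new Thread(() -> {
            try {
                while (true) {
                    Event e = bus.take();
                    store.add(e);    // 1. persist first
                    children.add(e); // 2. only then distribute to clients
                }
            } catch (InterruptedException ignored) {
                // shutdown
            }
        });
        processor.setDaemon(true);
        processor.start();
    }

    // Many request threads may publish concurrently; ordering is imposed
    // by the single consumer, not by the publishers.
    public void publish(Event e) {
        bus.add(e);
    }

    public static void main(String[] args) throws InterruptedException {
        MiniEventBus mainEventBus = new MiniEventBus();
        mainEventBus.publish(new Event("task-1", "submitted"));
        mainEventBus.publish(new Event("task-1", "completed"));

        // Wait (bounded) for the processor to drain the queue.
        long deadline = System.currentTimeMillis() + 2000;
        while (mainEventBus.children.size() < 2 && System.currentTimeMillis() < deadline) {
            Thread.sleep(10);
        }
        // Anything a client has seen is already persisted.
        System.out.println("persisted=" + mainEventBus.store.size()
                + " delivered=" + mainEventBus.children.size());
    }
}
```

Because one thread performs both steps in order, an observer that sees an
event in `children` can rely on it already being in `store` — the property
the real MainEventBusProcessor provides for the TaskStore and ChildQueues.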
**Core Architecture:**
- server-common/.../events/MainEventBus.java (new)
- server-common/.../events/MainEventBusProcessor.java (new)
- server-common/.../events/EventQueue.java (requires MainEventBus)
- server-common/.../events/InMemoryQueueManager.java (queue lifecycle)

**Replication:**
- extras/queue-manager-replicated/core/.../ReplicatedQueueManager.java (event conversion)
- extras/queue-manager-replicated/core/.../ReplicatedEventQueueItem.java (Task support)

**Client Transports:**
- client/transport/jsonrpc/.../SSEEventListener.java (enhanced auto-close)
- client/transport/rest/.../RestSSEEventListener.java (add auto-close)

**Event Processing:**
- server-common/.../events/EventConsumer.java (grace period + buffer flush)
- server-common/.../requesthandlers/DefaultRequestHandler.java (MainEventBus integration)

**Task Management:**
- server-common/.../tasks/TaskStore.java (TaskStateProvider interface)
- extras/task-store-database-jpa/.../JpaDatabaseTaskStore.java (implement TaskStateProvider)

✅ All unit tests pass (150+ tests)
✅ MultiInstanceReplicationTest passes (both instances receive all events)
✅ TCK tests pass (no connection leaks)
✅ Integration tests pass (EventConsumer, QueueManager, TaskStore)
---
 .../jsonrpc/sse/SSEEventListener.java | 20 +-
 .../rest/sse/RestSSEEventListener.java | 24 +-
 examples/cloud-deployment/scripts/deploy.sh | 16 +
 .../common/events/TaskFinalizedEvent.java | 14 +-
 ...paDatabasePushNotificationConfigStore.java | 1 +
 .../core/ReplicatedEventQueueItem.java | 10 +
 .../core/ReplicatedQueueManager.java | 103 +-
 .../core/ReplicatedQueueManagerTest.java | 306 +++--
 .../io/a2a/server/events/EventQueueUtil.java | 11 +
 .../src/main/resources/application.properties | 2 +
 .../src/main/resources/application.properties | 2 +
 .../MultiInstanceReplicationTest.java | 97 +-
 .../KafkaReplicationIntegrationTest.java | 12 +
 .../database/jpa/JpaDatabaseTaskStore.java | 16 +-
 .../jpa/JpaDatabaseTaskStoreTest.java | 64 +-
 .../server/apps/quarkus/A2AServerRoutes.java | 120 +-
 .../src/test/resources/application.properties | 5 +
 .../server/rest/quarkus/A2AServerRoutes.java | 144 ++-
 .../java/io/a2a/server/ServerCallContext.java | 61 +
 .../io/a2a/server/events/EventConsumer.java | 80 ++
 .../java/io/a2a/server/events/EventQueue.java | 427 ++--
 .../server/events/InMemoryQueueManager.java | 32 +-
 .../io/a2a/server/events/MainEventBus.java | 42 +
 .../server/events/MainEventBusContext.java | 11 +
 .../server/events/MainEventBusProcessor.java | 385 +++++++
 .../events/MainEventBusProcessorCallback.java | 66 ++
 .../MainEventBusProcessorInitializer.java | 43 +
 .../io/a2a/server/events/QueueManager.java | 26 +-
 .../DefaultRequestHandler.java | 506 +++++----
 .../a2a/server/tasks/InMemoryTaskStore.java | 3 +-
 .../io/a2a/server/tasks/ResultAggregator.java | 128 ++-
 .../java/io/a2a/server/tasks/TaskManager.java | 38 +-
 .../java/io/a2a/server/tasks/TaskStore.java | 5 +-
 .../util/async/AsyncExecutorProducer.java | 57 +-
 .../async/EventConsumerExecutorProducer.java | 93 ++
 .../io/a2a/server/util/sse/SseFormatter.java | 136 +++
 .../io/a2a/server/util/sse/package-info.java | 11 +
 .../META-INF/a2a-defaults.properties | 4 +
 .../a2a/server/events/EventConsumerTest.java | 100 +-
 .../io/a2a/server/events/EventQueueTest.java | 227 ++--
 .../io/a2a/server/events/EventQueueUtil.java | 37 +-
 .../events/InMemoryQueueManagerTest.java | 34 +-
 .../AbstractA2ARequestHandlerTest.java | 23 +-
 .../DefaultRequestHandlerTest.java | 1001 -----------------
 .../server/tasks/InMemoryTaskStoreTest.java | 49 -
 .../server/tasks/ResultAggregatorTest.java | 81 +-
 .../io/a2a/server/tasks/TaskManagerTest.java | 91 +-
 .../io/a2a/server/tasks/TaskUpdaterTest.java | 58 +-
 tck/src/main/resources/application.properties | 1 +
 .../apps/common/AbstractA2AServerTest.java | 63 ++
 .../a2a/server/apps/common/TestUtilsBean.java | 2 +-
 .../transport/grpc/handler/GrpcHandler.java | 24 +-
 .../grpc/handler/GrpcHandlerTest.java | 83 +-
 .../jsonrpc/handler/JSONRPCHandlerTest.java | 416 ++--
 .../transport/rest/handler/RestHandler.java | 15 +
 .../rest/handler/RestHandlerTest.java | 20 +-
 56 files changed, 3236 insertions(+), 2210 deletions(-)
 create mode 100644 extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBus.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java
 create mode 100644 server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java
 create mode 100644 server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java
 create mode 100644 server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java
 create mode 100644 server-common/src/main/java/io/a2a/server/util/sse/package-info.java

diff --git a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java
index 33025ed5e..3a24f5145 100644
--- a/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java
+++ b/client/transport/jsonrpc/src/main/java/io/a2a/client/transport/jsonrpc/sse/SSEEventListener.java
@@ -10,6 +10,8 @@
 import io.a2a.jsonrpc.common.json.JsonProcessingException;
 import io.a2a.spec.A2AError;
 import io.a2a.spec.StreamingEventKind;
+import io.a2a.spec.Task;
+import io.a2a.spec.TaskState;
 import io.a2a.spec.TaskStatusUpdateEvent;
 import org.jspecify.annotations.Nullable;
@@ -64,11 +66,23 @@ private void handleMessage(String message, @Nullable Future future) {
             StreamingEventKind event = ProtoUtils.FromProto.streamingEventKind(response);
             eventHandler.accept(event);
-            if (event instanceof TaskStatusUpdateEvent && ((TaskStatusUpdateEvent) event).isFinal()) {
-                if (future != null) {
-                    future.cancel(true); // close SSE channel
+
+            // Client-side auto-close on final events to prevent connection leaks
+            // Handles both TaskStatusUpdateEvent and Task objects with final states
+            // This covers late subscriptions to completed tasks and ensures no connection leaks
+            boolean shouldClose = false;
+            if (event instanceof TaskStatusUpdateEvent tue && tue.isFinal()) {
+                shouldClose = true;
+            } else if (event instanceof Task task) {
+                TaskState state = task.status().state();
+                if (state.isFinal()) {
+                    shouldClose = true;
                 }
             }
+
+            if (shouldClose && future != null) {
+                future.cancel(true); // close SSE channel
+            }
         } catch (A2AError error) {
             if (errorHandler != null) {
                 errorHandler.accept(error);
diff --git a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java
index ec74d2fbc..85e604da3 100644
--- a/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java
+++ b/client/transport/rest/src/main/java/io/a2a/client/transport/rest/sse/RestSSEEventListener.java
@@ -10,6 +10,9 @@
 import io.a2a.grpc.StreamResponse;
 import io.a2a.grpc.utils.ProtoUtils;
 import io.a2a.spec.StreamingEventKind;
+import io.a2a.spec.Task;
+import io.a2a.spec.TaskState;
+import io.a2a.spec.TaskStatusUpdateEvent;
 import org.jspecify.annotations.Nullable;

 public class RestSSEEventListener {
@@ -29,7 +32,7 @@ public void onMessage(String message, @Nullable Future completableFuture)
             log.fine("Streaming message received: " + message);
             io.a2a.grpc.StreamResponse.Builder builder = io.a2a.grpc.StreamResponse.newBuilder();
             JsonFormat.parser().merge(message, builder);
-            handleMessage(builder.build());
+            handleMessage(builder.build(), completableFuture);
         } catch (InvalidProtocolBufferException e) {
             errorHandler.accept(RestErrorMapper.mapRestError(message, 500));
         }
@@ -44,7 +47,7 @@ public void onError(Throwable throwable, @Nullable Future future) {
         }
     }

-    private void handleMessage(StreamResponse response) {
+    private void handleMessage(StreamResponse response, @Nullable Future future) {
         StreamingEventKind event;
         switch (response.getPayloadCase()) {
             case MESSAGE ->
@@ -62,6 +65,23 @@ private void handleMessage(StreamResponse response) {
             }
         }
         eventHandler.accept(event);
+
+        // Client-side auto-close on final events to prevent connection leaks
+        // Handles both TaskStatusUpdateEvent and Task objects with final states
+        // This covers late subscriptions to completed tasks and ensures no connection leaks
+        boolean shouldClose = false;
+        if (event instanceof TaskStatusUpdateEvent tue && tue.isFinal()) {
+            shouldClose = true;
+        } else if (event instanceof Task task) {
+            TaskState state = task.status().state();
+            if (state.isFinal()) {
+                shouldClose = true;
+            }
+        }
+
+        if (shouldClose && future != null) {
+            future.cancel(true); // close SSE channel
+        }
     }
 }
diff --git a/examples/cloud-deployment/scripts/deploy.sh b/examples/cloud-deployment/scripts/deploy.sh
index e267f3302..fff2a6061 100755
--- a/examples/cloud-deployment/scripts/deploy.sh
+++ b/examples/cloud-deployment/scripts/deploy.sh
@@ -212,6 +212,22 @@ echo ""
 echo "Deploying PostgreSQL..."
 kubectl apply -f ../k8s/01-postgres.yaml
 echo "Waiting for PostgreSQL to be ready..."
+
+# Wait for pod to be created (StatefulSet takes time to create pod)
+for i in {1..30}; do
+  if kubectl get pod -l app=postgres -n a2a-demo 2>/dev/null | grep -q postgres; then
+    echo "PostgreSQL pod found, waiting for ready state..."
+    break
+  fi
+  if [ $i -eq 30 ]; then
+    echo -e "${RED}ERROR: PostgreSQL pod not created after 30 seconds${NC}"
+    kubectl get statefulset -n a2a-demo
+    exit 1
+  fi
+  sleep 1
+done
+
+# Now wait for pod to be ready
 kubectl wait --for=condition=Ready pod -l app=postgres -n a2a-demo --timeout=120s
 echo -e "${GREEN}✓ PostgreSQL deployed${NC}"
diff --git a/extras/common/src/main/java/io/a2a/extras/common/events/TaskFinalizedEvent.java b/extras/common/src/main/java/io/a2a/extras/common/events/TaskFinalizedEvent.java
index 8c5f59348..0c35bad7a 100644
--- a/extras/common/src/main/java/io/a2a/extras/common/events/TaskFinalizedEvent.java
+++ b/extras/common/src/main/java/io/a2a/extras/common/events/TaskFinalizedEvent.java
@@ -5,22 +5,28 @@
  * This event is fired AFTER the database transaction commits, making it safe for downstream
  * components to assume the task is durably stored.
  *
- * Used by the replicated queue manager to send poison pill events after ensuring
- * the final task state is committed to the database, eliminating race conditions.
+ * Used by the replicated queue manager to send the final task state before the poison pill,
+ * ensuring correct event ordering across instances and eliminating race conditions.
  */
 public class TaskFinalizedEvent {
     private final String taskId;
+    private final Object task; // Task type from io.a2a.spec - using Object to avoid dependency

-    public TaskFinalizedEvent(String taskId) {
+    public TaskFinalizedEvent(String taskId, Object task) {
         this.taskId = taskId;
+        this.task = task;
     }

     public String getTaskId() {
         return taskId;
     }

+    public Object getTask() {
+        return task;
+    }
+
     @Override
     public String toString() {
-        return "TaskFinalizedEvent{taskId='" + taskId + "'}";
+        return "TaskFinalizedEvent{taskId='" + taskId + "', task=" + task + "}";
     }
 }
diff --git a/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java b/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java
index 36245e277..5049bc9a4 100644
--- a/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java
+++ b/extras/push-notification-config-store-database-jpa/src/main/java/io/a2a/extras/pushnotificationconfigstore/database/jpa/JpaDatabasePushNotificationConfigStore.java
@@ -164,4 +164,5 @@ public void deleteInfo(String taskId, String configId) {
                 taskId, configId);
         }
     }
+
 }
diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedEventQueueItem.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedEventQueueItem.java
index 87c10fb4e..206e07f03 100644
--- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedEventQueueItem.java
+++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedEventQueueItem.java
@@ -149,6 +149,16 @@ public void setClosedEvent(boolean closedEvent) {
         }
     }

+    /**
+     * Check if this event is a Task event.
+     * Task events should always be processed even for inactive tasks,
+     * as they carry the final task state.
+     * @return true if this is a Task event
+     */
+    public boolean isTaskEvent() {
+        return event instanceof io.a2a.spec.Task;
+    }
+
     @Override
     public String toString() {
         return "ReplicatedEventQueueItem{" +
diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java
index 586ab11a7..d232dea98 100644
--- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java
+++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java
@@ -13,8 +13,10 @@
 import io.a2a.server.events.EventQueueFactory;
 import io.a2a.server.events.EventQueueItem;
 import io.a2a.server.events.InMemoryQueueManager;
+import io.a2a.server.events.MainEventBus;
 import io.a2a.server.events.QueueManager;
 import io.a2a.server.tasks.TaskStateProvider;
+import org.jspecify.annotations.Nullable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -45,10 +47,12 @@ protected ReplicatedQueueManager() {
     }

     @Inject
-    public ReplicatedQueueManager(ReplicationStrategy replicationStrategy, TaskStateProvider taskStateProvider) {
+    public ReplicatedQueueManager(ReplicationStrategy replicationStrategy,
+                                  TaskStateProvider taskStateProvider,
+                                  MainEventBus mainEventBus) {
         this.replicationStrategy = replicationStrategy;
         this.taskStateProvider = taskStateProvider;
-        this.delegate = new InMemoryQueueManager(new ReplicatingEventQueueFactory(), taskStateProvider);
+        this.delegate = new InMemoryQueueManager(new ReplicatingEventQueueFactory(), taskStateProvider, mainEventBus);
     }
@@ -77,8 +81,7 @@ public void close(String taskId) {

     @Override
     public EventQueue createOrTap(String taskId) {
-        EventQueue queue = delegate.createOrTap(taskId);
-        return queue;
+        return delegate.createOrTap(taskId);
     }

     @Override
@@ -87,9 +90,11 @@ public void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedExcep
     }

     public void onReplicatedEvent(@Observes ReplicatedEventQueueItem replicatedEvent) {
-        // Check if task is still active before processing replicated event (unless it's a QueueClosedEvent)
-        // QueueClosedEvent should always be processed to terminate streams, even for inactive tasks
+        // Check if task is still active before processing replicated event
+        // Always allow QueueClosedEvent and Task events (they carry final state)
+        // Skip other event types for inactive tasks to prevent queue creation for expired tasks
         if (!replicatedEvent.isClosedEvent()
+                && !replicatedEvent.isTaskEvent()
                 && !taskStateProvider.isTaskActive(replicatedEvent.getTaskId())) {
             // Task is no longer active - skip processing this replicated event
             // This prevents creating queues for tasks that have been finalized beyond the grace period
@@ -97,38 +102,69 @@
             return;
         }

-        // Get or create a ChildQueue for this task (creates MainQueue if it doesn't exist)
-        EventQueue childQueue = delegate.createOrTap(replicatedEvent.getTaskId());
-
-        try {
-            // Get the MainQueue to enqueue the replicated event item
-            // We must use enqueueItem (not enqueueEvent) to preserve the isReplicated() flag
-            // and avoid triggering the replication hook again (which would cause a replication loop)
-            EventQueue mainQueue = delegate.get(replicatedEvent.getTaskId());
-            if (mainQueue != null) {
-                mainQueue.enqueueItem(replicatedEvent);
-            } else {
-                LOGGER.warn("MainQueue not found for task {}, cannot enqueue replicated event. This may happen if the queue was already cleaned up.",
-                        replicatedEvent.getTaskId());
-            }
-        } finally {
-            // Close the temporary ChildQueue to prevent leaks
-            // The MainQueue remains open for other consumers
-            childQueue.close();
+        // Get the MainQueue to enqueue the replicated event item
+        // We must use enqueueItem (not enqueueEvent) to preserve the isReplicated() flag
+        // and avoid triggering the replication hook again (which would cause a replication loop)
+        //
+        // IMPORTANT: We must NOT create a ChildQueue here! Creating and immediately closing
+        // a ChildQueue means there are zero children when MainEventBusProcessor distributes
+        // the event. Existing ChildQueues (from active client subscriptions) will receive
+        // the event when MainEventBusProcessor distributes it to all children.
+        //
+        // If MainQueue doesn't exist, create it. This handles late-arriving replicated events
+        // for tasks that were created on another instance.
+        EventQueue mainQueue = delegate.get(replicatedEvent.getTaskId());
+        if (mainQueue == null) {
+            // MainQueue doesn't exist - create it by calling createOrTap and then get the MainQueue
+            // Replicated events should always have real task IDs (not temp IDs) because
+            // replication now happens AFTER TaskStore persistence in MainEventBusProcessor
+            LOGGER.debug("Creating MainQueue for replicated event on task {}", replicatedEvent.getTaskId());
+            delegate.createOrTap(replicatedEvent.getTaskId());
+            mainQueue = delegate.get(replicatedEvent.getTaskId());
+        }
+
+        if (mainQueue != null) {
+            mainQueue.enqueueItem(replicatedEvent);
+        } else {
+            LOGGER.warn("MainQueue not found for task {}, cannot enqueue replicated event. This may happen if the queue was already cleaned up.",
+                    replicatedEvent.getTaskId());
         }
     }

     /**
      * Observes task finalization events fired AFTER database transaction commits.
-     * This guarantees the task's final state is durably stored before sending the poison pill.
+     * This guarantees the task's final state is durably stored before replication.
+     *
+     * Sends TaskStatusUpdateEvent (not full Task) FIRST, then the poison pill (QueueClosedEvent),
+     * ensuring correct event ordering across instances and eliminating race conditions where
+     * the poison pill arrives before the final task state.
      *
-     * @param event the task finalized event containing the task ID
+     * IMPORTANT: We send TaskStatusUpdateEvent instead of full Task to maintain consistency
+     * with local event distribution. Clients expect TaskStatusUpdateEvent for status changes,
+     * and sending the full Task causes issues in remote instances where clients don't handle
+     * bare Task objects the same way they handle TaskStatusUpdateEvent.
+     *
+     * @param event the task finalized event containing the task ID and final Task
      */
     public void onTaskFinalized(@Observes(during = TransactionPhase.AFTER_SUCCESS) TaskFinalizedEvent event) {
         String taskId = event.getTaskId();
-        LOGGER.debug("Task {} finalized - sending poison pill (QueueClosedEvent) after transaction commit", taskId);
+        io.a2a.spec.Task finalTask = (io.a2a.spec.Task) event.getTask(); // Cast from Object
+
+        LOGGER.debug("Task {} finalized - sending TaskStatusUpdateEvent then poison pill (QueueClosedEvent) after transaction commit", taskId);
+
+        // Convert final Task to TaskStatusUpdateEvent to match local event distribution
+        // This ensures remote instances receive the same event type as local instances
+        io.a2a.spec.TaskStatusUpdateEvent finalStatusEvent = io.a2a.spec.TaskStatusUpdateEvent.builder()
+                .taskId(taskId)
+                .contextId(finalTask.contextId())
+                .status(finalTask.status())
+                .isFinal(true)
+                .build();
+
+        // Send TaskStatusUpdateEvent FIRST to ensure it arrives before poison pill
+        replicationStrategy.send(taskId, finalStatusEvent);

-        // Send poison pill directly via replication strategy
+        // Then send poison pill
         // The transaction has committed, so the final state is guaranteed to be in the database
         io.a2a.server.events.QueueClosedEvent closedEvent = new io.a2a.server.events.QueueClosedEvent(taskId);
         replicationStrategy.send(taskId, closedEvent);
@@ -152,12 +188,11 @@ public EventQueue.EventQueueBuilder builder(String taskId) {
         // which sends the QueueClosedEvent after the database transaction commits.
         // This ensures proper ordering and transactional guarantees.

-        // Return the builder with callbacks
-        return delegate.getEventQueueBuilder(taskId)
-                .taskId(taskId)
-                .hook(new ReplicationHook(taskId))
-                .addOnCloseCallback(delegate.getCleanupCallback(taskId))
-                .taskStateProvider(taskStateProvider);
+        // Call createBaseEventQueueBuilder() directly to avoid infinite recursion
+        // (getEventQueueBuilder() would delegate back to this factory, creating a loop)
+        // The base builder already includes: taskId, cleanup callback, taskStateProvider, mainEventBus
+        return delegate.createBaseEventQueueBuilder(taskId)
+                .hook(new ReplicationHook(taskId));
     }
 }
diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java
index 43571cd30..7f53bfbfd 100644
--- a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java
+++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java
@@ -22,12 +22,19 @@
 import io.a2a.server.events.EventQueueClosedException;
 import io.a2a.server.events.EventQueueItem;
 import io.a2a.server.events.EventQueueTestHelper;
+import io.a2a.server.events.EventQueueUtil;
+import io.a2a.server.events.MainEventBus;
+import io.a2a.server.events.MainEventBusProcessor;
 import io.a2a.server.events.QueueClosedEvent;
+import io.a2a.server.tasks.InMemoryTaskStore;
+import io.a2a.server.tasks.PushNotificationSender;
 import io.a2a.spec.Event;
 import io.a2a.spec.StreamingEventKind;
+import io.a2a.spec.Task;
 import io.a2a.spec.TaskState;
 import io.a2a.spec.TaskStatus;
 import io.a2a.spec.TaskStatusUpdateEvent;
+import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -35,10 +42,27 @@ class ReplicatedQueueManagerTest {

     private ReplicatedQueueManager queueManager;
     private StreamingEventKind testEvent;
+    private MainEventBus mainEventBus;
+    private MainEventBusProcessor mainEventBusProcessor;
+    private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {};

     @BeforeEach
     void setUp() {
-        queueManager = new ReplicatedQueueManager(new NoOpReplicationStrategy(), new MockTaskStateProvider(true));
+        // Create MainEventBus first
+        InMemoryTaskStore taskStore = new InMemoryTaskStore();
+        mainEventBus = new MainEventBus();
+
+        // Create QueueManager before MainEventBusProcessor (processor needs it as parameter)
+        queueManager = new ReplicatedQueueManager(
+                new NoOpReplicationStrategy(),
+                new MockTaskStateProvider(true),
+                mainEventBus
+        );
+
+        // Create MainEventBusProcessor with QueueManager
+        mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager);
+        EventQueueUtil.start(mainEventBusProcessor);
+
         testEvent = TaskStatusUpdateEvent.builder()
                 .taskId("test-task")
                 .contextId("test-context")
@@ -47,25 +71,82 @@ void setUp() {
                 .build();
     }

+    /**
+     * Helper to create a test event with the specified taskId.
+     * This ensures taskId consistency between queue creation and event creation.
+     */
+    private TaskStatusUpdateEvent createEventForTask(String taskId) {
+        return TaskStatusUpdateEvent.builder()
+                .taskId(taskId)
+                .contextId("test-context")
+                .status(new TaskStatus(TaskState.SUBMITTED))
+                .isFinal(false)
+                .build();
+    }
+
+    @AfterEach
+    void tearDown() {
+        if (mainEventBusProcessor != null) {
+            mainEventBusProcessor.setCallback(null); // Clear any test callbacks
+            EventQueueUtil.stop(mainEventBusProcessor);
+        }
+        mainEventBusProcessor = null;
+        mainEventBus = null;
+        queueManager = null;
+    }
+
+    /**
+     * Helper to wait for MainEventBusProcessor to process an event.
+     * Replaces polling patterns with deterministic callback-based waiting.
+     *
+     * @param action the action that triggers event processing
+     * @throws InterruptedException if waiting is interrupted
+     * @throws AssertionError if processing doesn't complete within timeout
+     */
+    private void waitForEventProcessing(Runnable action) throws InterruptedException {
+        CountDownLatch processingLatch = new CountDownLatch(1);
+        mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() {
+            @Override
+            public void onEventProcessed(String taskId, io.a2a.spec.Event event) {
+                processingLatch.countDown();
+            }
+
+            @Override
+            public void onTaskFinalized(String taskId) {
+                // Not needed for basic event processing wait
+            }
+        });
+
+        try {
+            action.run();
+            assertTrue(processingLatch.await(5, TimeUnit.SECONDS),
+                    "MainEventBusProcessor should have processed the event within timeout");
+        } finally {
+            mainEventBusProcessor.setCallback(null);
+        }
+    }
+
     @Test
     void testReplicationStrategyTriggeredOnNormalEnqueue() throws InterruptedException {
         CountingReplicationStrategy strategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus);

         String taskId = "test-task-1";
         EventQueue queue = queueManager.createOrTap(taskId);
+        TaskStatusUpdateEvent event = createEventForTask(taskId);

-        queue.enqueueEvent(testEvent);
+        // Wait for MainEventBusProcessor to process the event and trigger replication
+        waitForEventProcessing(() -> queue.enqueueEvent(event));

         assertEquals(1, strategy.getCallCount());
         assertEquals(taskId, strategy.getLastTaskId());
-        assertEquals(testEvent, strategy.getLastEvent());
+        assertEquals(event, strategy.getLastEvent());
     }

     @Test
     void testReplicationStrategyNotTriggeredOnReplicatedEvent() throws InterruptedException {
         CountingReplicationStrategy strategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus);

         String taskId = "test-task-2";
         EventQueue queue = queueManager.createOrTap(taskId);
@@ -79,17 +160,19 @@ void testReplicationStrategyNotTriggeredOnReplicatedEvent() throws InterruptedEx
     @Test
     void testReplicationStrategyWithCountingImplementation() throws InterruptedException {
         CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true), mainEventBus);

         String taskId = "test-task-3";
         EventQueue queue = queueManager.createOrTap(taskId);
+        TaskStatusUpdateEvent event = createEventForTask(taskId);

-        queue.enqueueEvent(testEvent);
-        queue.enqueueEvent(testEvent);
+        // Wait for MainEventBusProcessor to process each event
+        waitForEventProcessing(() -> queue.enqueueEvent(event));
+        waitForEventProcessing(() -> queue.enqueueEvent(event));

         assertEquals(2, countingStrategy.getCallCount());
         assertEquals(taskId, countingStrategy.getLastTaskId());
-        assertEquals(testEvent, countingStrategy.getLastEvent());
+        assertEquals(event, countingStrategy.getLastEvent());

         ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent);
         queueManager.onReplicatedEvent(replicatedEvent);
@@ -100,46 +183,45 @@ void testReplicationStrategyWithCountingImplementation() throws InterruptedExcep
     @Test
     void testReplicatedEventDeliveredToCorrectQueue() throws InterruptedException {
         String taskId = "test-task-4";
+        TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId
         EventQueue queue = queueManager.createOrTap(taskId);

-        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent);
-        queueManager.onReplicatedEvent(replicatedEvent);
+        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask);

-        Event dequeuedEvent;
-        try {
-            dequeuedEvent = queue.dequeueEventItem(100).getEvent();
-        } catch (EventQueueClosedException e) {
-            fail("Queue should not be closed");
-            return;
-        }
-        assertEquals(testEvent, dequeuedEvent);
+        // Use callback to wait for event processing
+        EventQueueItem item = dequeueEventWithRetry(queue, () -> queueManager.onReplicatedEvent(replicatedEvent));
+        assertNotNull(item, "Event should be available in queue");
+        Event dequeuedEvent = item.getEvent();
+        assertEquals(eventForTask, dequeuedEvent);
     }

     @Test
     void testReplicatedEventCreatesQueueIfNeeded() throws InterruptedException {
         String taskId = "non-existent-task";
+        TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId

         // Verify no queue exists initially
         assertNull(queueManager.get(taskId));

-        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent);
-
-        // Process the replicated event
-        assertDoesNotThrow(() -> queueManager.onReplicatedEvent(replicatedEvent));
-
-        // Verify that a queue was created and the event was enqueued
-        EventQueue queue = queueManager.get(taskId);
-        assertNotNull(queue, "Queue should be created when processing replicated event for non-existent task");
-
-        // Verify the event was enqueued by dequeuing it
-        Event dequeuedEvent;
-        try {
-            dequeuedEvent = queue.dequeueEventItem(100).getEvent();
-        } catch (EventQueueClosedException e) {
-            fail("Queue should not be closed");
-            return;
-        }
-        assertEquals(testEvent, dequeuedEvent, "The replicated event should be enqueued in the newly created queue");
+        // Create a ChildQueue BEFORE processing the replicated event
+        // This ensures the ChildQueue exists when MainEventBusProcessor distributes the event
+        EventQueue childQueue = queueManager.createOrTap(taskId);
+        assertNotNull(childQueue, "ChildQueue should be created");
+
+        // Verify MainQueue was created
+        EventQueue mainQueue = queueManager.get(taskId);
+        assertNotNull(mainQueue, "MainQueue should exist after createOrTap");
+
+        ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask);
+
+        // Process the replicated event and wait for distribution
+        // Use callback to wait for event processing
+        EventQueueItem item = dequeueEventWithRetry(childQueue, () -> {
+            assertDoesNotThrow(() -> queueManager.onReplicatedEvent(replicatedEvent));
+        });
+        assertNotNull(item, "Event should be available in queue");
+        Event dequeuedEvent = item.getEvent();
+        assertEquals(eventForTask, dequeuedEvent, "The replicated event should be enqueued in the newly created queue");
     }

     @Test
@@ -170,17 +252,18 @@ void testBasicQueueManagerFunctionality() throws InterruptedException {
     void testQueueToTaskIdMappingMaintained() throws InterruptedException {
         String taskId = "test-task-6";
         CountingReplicationStrategy countingStrategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(countingStrategy, new MockTaskStateProvider(true), mainEventBus);
+        TaskStatusUpdateEvent event = createEventForTask(taskId);

         EventQueue queue = queueManager.createOrTap(taskId);
-        queue.enqueueEvent(testEvent);
+        waitForEventProcessing(() -> queue.enqueueEvent(event));
         assertEquals(taskId, countingStrategy.getLastTaskId());

         queueManager.close(taskId); // Task is active, so NO poison pill is sent

         EventQueue newQueue = queueManager.createOrTap(taskId);
-        newQueue.enqueueEvent(testEvent);
+        waitForEventProcessing(() -> newQueue.enqueueEvent(event));
         assertEquals(taskId, countingStrategy.getLastTaskId());

         // 2 replication calls: 1 testEvent, 1 testEvent (no QueueClosedEvent because task is active)
@@ -217,17 +300,32 @@ void testReplicatedEventJsonSerialization() throws Exception {
     @Test
     void testParallelReplicationBehavior() throws InterruptedException {
         CountingReplicationStrategy strategy = new CountingReplicationStrategy();
-        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true));
+        queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus);

         String taskId = "parallel-test-task";
         EventQueue queue = queueManager.createOrTap(taskId);

         int numThreads = 10;
         int eventsPerThread = 5;
+        int expectedEventCount = (numThreads / 2) * eventsPerThread; // Only normal enqueues

         ExecutorService executor = Executors.newFixedThreadPool(numThreads);
         CountDownLatch startLatch = new CountDownLatch(1);
         CountDownLatch doneLatch = new CountDownLatch(numThreads);

+        // Set up callback to wait for all events to be processed by MainEventBusProcessor
+        CountDownLatch processingLatch = new CountDownLatch(expectedEventCount);
+        mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() {
+            @Override
+            public void onEventProcessed(String tid, io.a2a.spec.Event event) {
+                processingLatch.countDown();
+            }
+
+            @Override
+            public void onTaskFinalized(String tid) {
+                // Not needed for this test
+            }
+        });
+
         // Launch threads that will enqueue events normally (should trigger replication)
         for (int i = 0; i < numThreads / 2; i++) {
             final int threadId = i;
@@ -236,7 +334,7 @@ void testParallelReplicationBehavior()
throws InterruptedException { startLatch.await(); for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() - .taskId("normal-" + threadId + "-" + j) + .taskId(taskId) // Use same taskId as queue .contextId("test-context") .status(new TaskStatus(TaskState.WORKING)) .isFinal(false) @@ -260,7 +358,7 @@ void testParallelReplicationBehavior() throws InterruptedException { startLatch.await(); for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() - .taskId("replicated-" + threadId + "-" + j) + .taskId(taskId) // Use same taskId as queue .contextId("test-context") .status(new TaskStatus(TaskState.COMPLETED)) .isFinal(true) @@ -286,6 +384,14 @@ void testParallelReplicationBehavior() throws InterruptedException { executor.shutdown(); assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS), "Executor should shutdown within 5 seconds"); + // Wait for MainEventBusProcessor to process all events + try { + assertTrue(processingLatch.await(10, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed all events within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + // Only the normal enqueue operations should have triggered replication // numThreads/2 threads * eventsPerThread events each = total expected replication calls int expectedReplicationCalls = (numThreads / 2) * eventsPerThread; @@ -297,7 +403,7 @@ void testParallelReplicationBehavior() throws InterruptedException { void testReplicatedEventSkippedWhenTaskInactive() throws InterruptedException { // Create a task state provider that returns false (task is inactive) MockTaskStateProvider stateProvider = new MockTaskStateProvider(false); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "inactive-task"; @@ -316,30 +422,32 
@@ void testReplicatedEventSkippedWhenTaskInactive() throws InterruptedException { void testReplicatedEventProcessedWhenTaskActive() throws InterruptedException { // Create a task state provider that returns true (task is active) MockTaskStateProvider stateProvider = new MockTaskStateProvider(true); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "active-task"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId // Verify no queue exists initially assertNull(queueManager.get(taskId)); - // Process a replicated event for an active task - ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, testEvent); - queueManager.onReplicatedEvent(replicatedEvent); + // Create a ChildQueue BEFORE processing the replicated event + // This ensures the ChildQueue exists when MainEventBusProcessor distributes the event + EventQueue childQueue = queueManager.createOrTap(taskId); + assertNotNull(childQueue, "ChildQueue should be created"); - // Queue should be created and event should be enqueued - EventQueue queue = queueManager.get(taskId); - assertNotNull(queue, "Queue should be created for active task"); + // Verify MainQueue was created + EventQueue mainQueue = queueManager.get(taskId); + assertNotNull(mainQueue, "MainQueue should exist after createOrTap"); - // Verify the event was enqueued - Event dequeuedEvent; - try { - dequeuedEvent = queue.dequeueEventItem(100).getEvent(); - } catch (EventQueueClosedException e) { - fail("Queue should not be closed"); - return; - } - assertEquals(testEvent, dequeuedEvent, "Event should be enqueued for active task"); + // Process a replicated event for an active task + ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, eventForTask); + + // Verify the event was enqueued and 
distributed to our ChildQueue + // Use callback to wait for event processing + EventQueueItem item = dequeueEventWithRetry(childQueue, () -> queueManager.onReplicatedEvent(replicatedEvent)); + assertNotNull(item, "Event should be available in queue"); + Event dequeuedEvent = item.getEvent(); + assertEquals(eventForTask, dequeuedEvent, "Event should be enqueued for active task"); } @@ -347,7 +455,7 @@ void testReplicatedEventProcessedWhenTaskActive() throws InterruptedException { void testReplicatedEventToExistingQueueWhenTaskBecomesInactive() throws InterruptedException { // Create a task state provider that returns true initially MockTaskStateProvider stateProvider = new MockTaskStateProvider(true); - queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider); + queueManager = new ReplicatedQueueManager(new CountingReplicationStrategy(), stateProvider, mainEventBus); String taskId = "task-becomes-inactive"; @@ -387,30 +495,38 @@ void testReplicatedEventToExistingQueueWhenTaskBecomesInactive() throws Interrup @Test void testPoisonPillSentViaTransactionAwareEvent() throws InterruptedException { CountingReplicationStrategy strategy = new CountingReplicationStrategy(); - queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true)); + queueManager = new ReplicatedQueueManager(strategy, new MockTaskStateProvider(true), mainEventBus); String taskId = "poison-pill-test"; EventQueue queue = queueManager.createOrTap(taskId); + TaskStatusUpdateEvent event = createEventForTask(taskId); - // Enqueue a normal event first - queue.enqueueEvent(testEvent); + // Enqueue a normal event first and wait for processing + waitForEventProcessing(() -> queue.enqueueEvent(event)); // In the new architecture, QueueClosedEvent (poison pill) is sent via CDI events // when JpaDatabaseTaskStore.save() persists a final task and the transaction commits // ReplicatedQueueManager.onTaskFinalized() observes AFTER_SUCCESS and sends the poison 
pill // Simulate the CDI event observer being called (what happens in real execution) - TaskFinalizedEvent taskFinalizedEvent = new TaskFinalizedEvent(taskId); + // Create a final task for the event + Task finalTask = Task.builder() + .id(taskId) + .contextId("test-context") + .status(new TaskStatus(TaskState.COMPLETED)) + .build(); + TaskFinalizedEvent taskFinalizedEvent = new TaskFinalizedEvent(taskId, finalTask); // Call the observer method directly (simulating CDI event delivery) queueManager.onTaskFinalized(taskFinalizedEvent); - // Verify that QueueClosedEvent was replicated - // strategy.getCallCount() should be 2: one for testEvent, one for QueueClosedEvent - assertEquals(2, strategy.getCallCount(), "Should have replicated both normal event and QueueClosedEvent"); + // Verify that final Task and QueueClosedEvent were replicated + // strategy.getCallCount() should be 3: testEvent, final Task, then QueueClosedEvent (poison pill) + assertEquals(3, strategy.getCallCount(), "Should have replicated testEvent, final Task, and QueueClosedEvent"); + // Verify the last event is QueueClosedEvent (poison pill) Event lastEvent = strategy.getLastEvent(); - assertTrue(lastEvent instanceof QueueClosedEvent, "Last replicated event should be QueueClosedEvent"); + assertTrue(lastEvent instanceof QueueClosedEvent, "Last replicated event should be QueueClosedEvent (poison pill)"); assertEquals(taskId, ((QueueClosedEvent) lastEvent).getTaskId()); } @@ -451,36 +567,21 @@ void testQueueClosedEventJsonSerialization() throws Exception { @Test void testReplicatedQueueClosedEventTerminatesConsumer() throws InterruptedException { String taskId = "remote-close-test"; + TaskStatusUpdateEvent eventForTask = createEventForTask(taskId); // Use matching taskId EventQueue queue = queueManager.createOrTap(taskId); - // Enqueue a normal event - queue.enqueueEvent(testEvent); - // Simulate receiving QueueClosedEvent from remote node QueueClosedEvent closedEvent = new QueueClosedEvent(taskId); 
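The replicated `QueueClosedEvent` handling exercised by this test is the classic poison-pill pattern: the consumer drains its queue until it sees a "closed" sentinel, then terminates. A minimal standalone sketch under assumed names (these stand-ins are illustrative, not the real `ReplicatedQueueManager`/`EventConsumer` API):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

public class SentinelQueueDemo {

    // Minimal stand-in for EventQueueItem: either a payload or the poison pill.
    record Event(String payload, boolean queueClosed) {}

    static List<String> drainUntilClosed(BlockingQueue<Event> queue) {
        List<String> seen = new ArrayList<>();
        while (true) {
            Event e;
            try {
                e = queue.take();
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
                break;
            }
            if (e.queueClosed()) {
                break; // poison pill: stop consuming, the queue was closed remotely
            }
            seen.add(e.payload());
        }
        return seen;
    }

    public static void main(String[] args) {
        BlockingQueue<Event> queue = new LinkedBlockingQueue<>();
        queue.add(new Event("status-working", false));
        queue.add(new Event("status-completed", false));
        queue.add(new Event(null, true)); // simulated replicated QueueClosedEvent
        System.out.println(drainUntilClosed(queue)); // [status-working, status-completed]
    }
}
```

Delivering the sentinel through the same ordered queue as normal events is what guarantees no event is lost before shutdown, which is why the test asserts the normal event is dequeued before the `QueueClosedEvent`.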
         ReplicatedEventQueueItem replicatedClosedEvent = new ReplicatedEventQueueItem(taskId, closedEvent);
-        queueManager.onReplicatedEvent(replicatedClosedEvent);
 
-        // Dequeue the normal event first
-        EventQueueItem item1;
-        try {
-            item1 = queue.dequeueEventItem(100);
-        } catch (EventQueueClosedException e) {
-            fail("Should not throw on first dequeue");
-            return;
-        }
-        assertNotNull(item1);
-        assertEquals(testEvent, item1.getEvent());
+        // Dequeue the normal event first (use callback to wait for async processing)
+        EventQueueItem item1 = dequeueEventWithRetry(queue, () -> queue.enqueueEvent(eventForTask));
+        assertNotNull(item1, "First event should be available");
+        assertEquals(eventForTask, item1.getEvent());
 
-        // Next dequeue should get the QueueClosedEvent
-        EventQueueItem item2;
-        try {
-            item2 = queue.dequeueEventItem(100);
-        } catch (EventQueueClosedException e) {
-            fail("Should not throw on second dequeue, should return the event");
-            return;
-        }
-        assertNotNull(item2);
+        // Next dequeue should get the QueueClosedEvent (use callback to wait for async processing)
+        EventQueueItem item2 = dequeueEventWithRetry(queue, () -> queueManager.onReplicatedEvent(replicatedClosedEvent));
+        assertNotNull(item2, "QueueClosedEvent should be available");
         assertTrue(item2.getEvent() instanceof QueueClosedEvent, "Second event should be QueueClosedEvent");
     }
 
@@ -539,4 +640,25 @@
         public void setActive(boolean active) {
             this.active = active;
         }
     }
+
+    /**
+     * Helper method to dequeue an event after waiting for MainEventBusProcessor distribution.
+     * Uses callback-based waiting instead of polling for deterministic synchronization.
+     *
+     * @param queue the queue to dequeue from
+     * @param enqueueAction the action that enqueues the event (triggers event processing)
+     * @return the dequeued EventQueueItem, or null if queue is closed
+     */
+    private EventQueueItem dequeueEventWithRetry(EventQueue queue, Runnable enqueueAction) throws InterruptedException {
+        // Wait for event to be processed and distributed
+        waitForEventProcessing(enqueueAction);
+
+        // Event is now available, dequeue directly
+        try {
+            return queue.dequeueEventItem(100);
+        } catch (EventQueueClosedException e) {
+            // Queue closed, return null
+            return null;
+        }
+    }
 }
\ No newline at end of file
diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java
new file mode 100644
index 000000000..a91575aaa
--- /dev/null
+++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/server/events/EventQueueUtil.java
@@ -0,0 +1,11 @@
+package io.a2a.server.events;
+
+public class EventQueueUtil {
+    public static void start(MainEventBusProcessor processor) {
+        processor.start();
+    }
+
+    public static void stop(MainEventBusProcessor processor) {
+        processor.stop();
+    }
+}
diff --git a/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-1/src/main/resources/application.properties b/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-1/src/main/resources/application.properties
index d0692ca53..ea64096cb 100644
--- a/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-1/src/main/resources/application.properties
+++ b/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-1/src/main/resources/application.properties
@@ -34,5 +34,7 @@ quarkus.messaging.kafka.health.timeout=5s
 # Enable debug logging
 quarkus.log.category."io.a2a.server.events".level=DEBUG
 quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG
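The `dequeueEventWithRetry`/`waitForEventProcessing` pattern used throughout these tests can be sketched in isolation: register a one-shot callback on the (simulated) background processor, trigger the enqueue, then block on a latch until the processor reports the event handled. This is a hypothetical standalone model of the test helper; the callback-registration hook is an assumption modeled on `MainEventBusProcessor.setCallback`, not the real API.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

public class CallbackWaitDemo {

    // Stand-in for MainEventBusProcessor's callback hook (illustrative only).
    static final AtomicReference<Runnable> callback = new AtomicReference<>();

    // Simulated async processor: runs the action on another thread, then notifies.
    static void processAsync(Runnable enqueueAction) {
        new Thread(() -> {
            enqueueAction.run(); // "persist to TaskStore, then distribute"
            Runnable cb = callback.get();
            if (cb != null) {
                cb.run();
            }
        }).start();
    }

    // Deterministic wait: no polling loops, no fixed sleeps.
    static boolean waitForEventProcessing(Runnable enqueueAction) {
        CountDownLatch latch = new CountDownLatch(1);
        callback.set(latch::countDown); // register BEFORE triggering to avoid a lost wakeup
        try {
            processAsync(enqueueAction);
            return latch.await(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        } finally {
            callback.set(null); // always clear, as the tests do in their finally blocks
        }
    }

    public static void main(String[] args) {
        System.out.println("processed=" + waitForEventProcessing(() -> {}));
    }
}
```

Registering the latch before triggering the enqueue is the important ordering: it guarantees the notification cannot fire before anyone is listening.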
+quarkus.log.category."io.a2a.server.tasks".level=DEBUG
 quarkus.log.category."io.a2a.extras.queuemanager.replicated".level=DEBUG
+quarkus.log.category."io.a2a.extras.taskstore.database.jpa".level=DEBUG
 quarkus.log.category."io.a2a.client".level=DEBUG
diff --git a/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-2/src/main/resources/application.properties b/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-2/src/main/resources/application.properties
index 0b647f3a5..d9a495e8f 100644
--- a/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-2/src/main/resources/application.properties
+++ b/extras/queue-manager-replicated/tests-multi-instance/quarkus-app-2/src/main/resources/application.properties
@@ -34,5 +34,7 @@ quarkus.messaging.kafka.health.timeout=5s
 # Enable debug logging
 quarkus.log.category."io.a2a.server.events".level=DEBUG
 quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG
+quarkus.log.category."io.a2a.server.tasks".level=DEBUG
 quarkus.log.category."io.a2a.extras.queuemanager.replicated".level=DEBUG
+quarkus.log.category."io.a2a.extras.taskstore.database.jpa".level=DEBUG
 quarkus.log.category."io.a2a.client".level=DEBUG
diff --git a/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/io/a2a/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java b/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/io/a2a/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java
index 93093388d..b98f0e87d 100644
--- a/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/io/a2a/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java
+++ b/extras/queue-manager-replicated/tests-multi-instance/tests/src/test/java/io/a2a/extras/queuemanager/replicated/tests/multiinstance/MultiInstanceReplicationTest.java
@@ -10,6 +10,7 @@ import java.util.Collections;
 import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; @@ -254,9 +255,11 @@ public void testMultiInstanceEventReplication() throws Exception { final String taskId = "replication-test-task-" + System.currentTimeMillis(); final String contextId = "replication-test-context"; - // Step 1: Send initial message NON-streaming to create task - Message initialMessage = Message.builder(A2A.toUserMessage("Initial test message")) - .taskId(taskId) + Throwable testFailure = null; + try { + // Step 1: Send initial message NON-streaming to create task + Message initialMessage = Message.builder(A2A.toUserMessage("Initial test message")) + .taskId(taskId) .contextId(contextId) .build(); @@ -308,11 +311,33 @@ public void testMultiInstanceEventReplication() throws Exception { AtomicReference app1Error = new AtomicReference<>(); AtomicReference app2Error = new AtomicReference<>(); + AtomicBoolean app1ReceivedInitialTask = new AtomicBoolean(false); + AtomicBoolean app2ReceivedInitialTask = new AtomicBoolean(false); // App1 subscriber BiConsumer app1Subscriber = (event, card) -> { + String eventDetail = event.getClass().getSimpleName(); + if (event instanceof io.a2a.client.TaskUpdateEvent tue) { + eventDetail += " [" + tue.getUpdateEvent().getClass().getSimpleName(); + if (tue.getUpdateEvent() instanceof io.a2a.spec.TaskStatusUpdateEvent statusEvent) { + eventDetail += ", state=" + (statusEvent.status() != null ? statusEvent.status().state() : "null"); + } + eventDetail += "]"; + } else if (event instanceof io.a2a.client.TaskEvent te) { + eventDetail += " [state=" + (te.getTask().status() != null ? 
te.getTask().status().state() : "null") + "]"; + } + System.out.println("APP1 received event: " + eventDetail); + + // Per A2A spec 3.1.6: Handle initial TaskEvent on resubscribe + if (!app1ReceivedInitialTask.get() && event instanceof io.a2a.client.TaskEvent) { + app1ReceivedInitialTask.set(true); + System.out.println("APP1 filtered initial TaskEvent"); + // Don't count initial TaskEvent toward expected artifact/status events + return; + } app1Events.add(event); - app1EventCount.incrementAndGet(); + int count = app1EventCount.incrementAndGet(); + System.out.println("APP1 event count now: " + count + ", event: " + eventDetail); }; Consumer app1ErrorHandler = error -> { @@ -323,8 +348,28 @@ public void testMultiInstanceEventReplication() throws Exception { // App2 subscriber BiConsumer app2Subscriber = (event, card) -> { + String eventDetail = event.getClass().getSimpleName(); + if (event instanceof io.a2a.client.TaskUpdateEvent tue) { + eventDetail += " [" + tue.getUpdateEvent().getClass().getSimpleName(); + if (tue.getUpdateEvent() instanceof io.a2a.spec.TaskStatusUpdateEvent statusEvent) { + eventDetail += ", state=" + (statusEvent.status() != null ? statusEvent.status().state() : "null"); + } + eventDetail += "]"; + } else if (event instanceof io.a2a.client.TaskEvent te) { + eventDetail += " [state=" + (te.getTask().status() != null ? 
te.getTask().status().state() : "null") + "]"; + } + System.out.println("APP2 received event: " + eventDetail); + + // Per A2A spec 3.1.6: Handle initial TaskEvent on resubscribe + if (!app2ReceivedInitialTask.get() && event instanceof io.a2a.client.TaskEvent) { + app2ReceivedInitialTask.set(true); + System.out.println("APP2 filtered initial TaskEvent"); + // Don't count initial TaskEvent toward expected artifact/status events + return; + } app2Events.add(event); - app2EventCount.incrementAndGet(); + int count = app2EventCount.incrementAndGet(); + System.out.println("APP2 event count now: " + count + ", event: " + eventDetail); }; Consumer app2ErrorHandler = error -> { @@ -409,9 +454,45 @@ public void testMultiInstanceEventReplication() throws Exception { throw new AssertionError("App2 subscriber error", app2Error.get()); } - // Verify both received at least 3 events (could be more due to initial state events) - assertTrue(app1Events.size() >= 3, "App1 should receive at least 3 events, got: " + app1Events.size()); - assertTrue(app2Events.size() >= 3, "App2 should receive at least 3 events, got: " + app2Events.size()); + // Verify both received at least 3 events (could be more due to initial state events) + assertTrue(app1Events.size() >= 3, "App1 should receive at least 3 events, got: " + app1Events.size()); + assertTrue(app2Events.size() >= 3, "App2 should receive at least 3 events, got: " + app2Events.size()); + } catch (Throwable t) { + testFailure = t; + throw t; + } finally { + // Output container logs if test failed + if (testFailure != null) { + System.err.println("\n========================================"); + System.err.println("TEST FAILED - Dumping container logs"); + System.err.println("========================================\n"); + + dumpContainerLogs("KAFKA", kafka, 100); + dumpContainerLogs("APP1", app1, 200); + dumpContainerLogs("APP2", app2, 200); + + System.err.println("\n========================================"); + System.err.println("END OF 
CONTAINER LOGS"); + System.err.println("========================================\n"); + } + } + } + + /** + * Dumps the last N lines of logs from a container to stderr. + */ + private void dumpContainerLogs(String containerName, org.testcontainers.containers.ContainerState container, int lastLines) { + System.err.println("\n--- " + containerName + " LOGS (last " + lastLines + " lines) ---"); + try { + String logs = container.getLogs(); + String[] lines = logs.split("\n"); + int start = Math.max(0, lines.length - lastLines); + for (int i = start; i < lines.length; i++) { + System.err.println(lines[i]); + } + } catch (Exception e) { + System.err.println("Failed to retrieve " + containerName + " logs: " + e.getMessage()); + } } /** diff --git a/extras/queue-manager-replicated/tests-single-instance/src/test/java/io/a2a/extras/queuemanager/replicated/tests/KafkaReplicationIntegrationTest.java b/extras/queue-manager-replicated/tests-single-instance/src/test/java/io/a2a/extras/queuemanager/replicated/tests/KafkaReplicationIntegrationTest.java index 3825f3bad..c4b23fb0d 100644 --- a/extras/queue-manager-replicated/tests-single-instance/src/test/java/io/a2a/extras/queuemanager/replicated/tests/KafkaReplicationIntegrationTest.java +++ b/extras/queue-manager-replicated/tests-single-instance/src/test/java/io/a2a/extras/queuemanager/replicated/tests/KafkaReplicationIntegrationTest.java @@ -222,9 +222,21 @@ public void testKafkaEventReceivedByA2AServer() throws Exception { AtomicReference receivedCompletedEvent = new AtomicReference<>(); AtomicBoolean wasUnexpectedEvent = new AtomicBoolean(false); AtomicReference errorRef = new AtomicReference<>(); + AtomicBoolean receivedInitialTask = new AtomicBoolean(false); // Create consumer to handle resubscribed events BiConsumer consumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!receivedInitialTask.get()) { + if (event instanceof TaskEvent) { + receivedInitialTask.set(true); + return; 
+ } else { + throw new AssertionError("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent taskUpdateEvent) { if (taskUpdateEvent.getUpdateEvent() instanceof TaskStatusUpdateEvent statusEvent) { if (statusEvent.status().state() == TaskState.COMPLETED) { diff --git a/extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java b/extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java index b65b71650..a4afa8588 100644 --- a/extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java +++ b/extras/task-store-database-jpa/src/main/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStore.java @@ -65,19 +65,23 @@ void initConfig() { @Transactional @Override - public void save(Task task) { - LOGGER.debug("Saving task with ID: {}", task.id()); + public void save(Task task, boolean isReplicated) { + LOGGER.debug("Saving task with ID: {} (replicated: {})", task.id(), isReplicated); try { JpaTask jpaTask = JpaTask.createFromTask(task); em.merge(jpaTask); LOGGER.debug("Persisted/updated task with ID: {}", task.id()); - if (task.status() != null && task.status().state() != null && task.status().state().isFinal()) { + // Only fire TaskFinalizedEvent for locally-generated final states, NOT for replicated events + // This prevents feedback loops where receiving a replicated final task triggers another replication + if (!isReplicated && task.status() != null && task.status().state() != null && task.status().state().isFinal()) { // Fire CDI event if task reached final state // IMPORTANT: The event will be delivered AFTER transaction commits (AFTER_SUCCESS observers) - // This ensures the task's final state is durably stored before the QueueClosedEvent poison pill is sent - LOGGER.debug("Task {} is in final state, 
firing TaskFinalizedEvent", task.id()); - taskFinalizedEvent.fire(new TaskFinalizedEvent(task.id())); + // This ensures the task's final state is durably stored before the final task and poison pill are sent + LOGGER.debug("Task {} is in final state, firing TaskFinalizedEvent with full Task", task.id()); + taskFinalizedEvent.fire(new TaskFinalizedEvent(task.id(), task)); + } else if (isReplicated && task.status() != null && task.status().state() != null && task.status().state().isFinal()) { + LOGGER.debug("Task {} is in final state but from replication - NOT firing TaskFinalizedEvent (prevents feedback loop)", task.id()); } } catch (JsonProcessingException e) { LOGGER.error("Failed to serialize task with ID: {}", task.id(), e); diff --git a/extras/task-store-database-jpa/src/test/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java b/extras/task-store-database-jpa/src/test/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java index ea77f73c7..15c01a626 100644 --- a/extras/task-store-database-jpa/src/test/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java +++ b/extras/task-store-database-jpa/src/test/java/io/a2a/extras/taskstore/database/jpa/JpaDatabaseTaskStoreTest.java @@ -54,7 +54,7 @@ public void testSaveAndRetrieveTask() { .build(); // Save the task - taskStore.save(task); + taskStore.save(task, false); // Retrieve the task Task retrieved = taskStore.get("test-task-1"); @@ -84,7 +84,7 @@ public void testSaveAndRetrieveTaskWithHistory() { .build(); // Save the task - taskStore.save(task); + taskStore.save(task, false); // Retrieve the task Task retrieved = taskStore.get("test-task-2"); @@ -108,7 +108,7 @@ public void testUpdateExistingTask() { .status(new TaskStatus(TaskState.SUBMITTED)) .build(); - taskStore.save(initialTask); + taskStore.save(initialTask, false); // Update the task Task updatedTask = Task.builder() @@ -117,7 +117,7 @@ public void testUpdateExistingTask() { .status(new 
                TaskStatus(TaskState.COMPLETED))
                .build();
-        taskStore.save(updatedTask);
+        taskStore.save(updatedTask, false);
 
         // Retrieve and verify the update
         Task retrieved = taskStore.get("test-task-3");
@@ -144,7 +144,7 @@ public void testDeleteTask() {
                 .status(new TaskStatus(TaskState.SUBMITTED))
                 .build();
 
-        taskStore.save(task);
+        taskStore.save(task, false);
 
         // Verify it exists
         assertNotNull(taskStore.get("test-task-4"));
@@ -180,7 +180,7 @@ public void testTaskWithComplexMetadata() {
                 .build();
 
         // Save and retrieve
-        taskStore.save(task);
+        taskStore.save(task, false);
         Task retrieved = taskStore.get("test-task-5");
 
         assertNotNull(retrieved);
@@ -201,7 +201,7 @@ public void testIsTaskActiveForNonFinalTask() {
                 .status(new TaskStatus(TaskState.WORKING))
                 .build();
 
-        taskStore.save(task);
+        taskStore.save(task, false);
 
         // Task should be active (not in final state)
         JpaDatabaseTaskStore jpaDatabaseTaskStore = (JpaDatabaseTaskStore) taskStore;
@@ -220,7 +220,7 @@ public void testIsTaskActiveForFinalTaskWithinGracePeriod() {
                 .status(new TaskStatus(TaskState.WORKING))
                 .build();
 
-        taskStore.save(task);
+        taskStore.save(task, false);
 
         // Update to final state
         Task finalTask = Task.builder()
@@ -229,7 +229,7 @@ public void testIsTaskActiveForFinalTaskWithinGracePeriod() {
                 .status(new TaskStatus(TaskState.COMPLETED))
                 .build();
 
-        taskStore.save(finalTask);
+        taskStore.save(finalTask, false);
 
         // Task should be active (within grace period - default 15 seconds)
         JpaDatabaseTaskStore jpaDatabaseTaskStore = (JpaDatabaseTaskStore) taskStore;
@@ -248,7 +248,7 @@ public void testIsTaskActiveForFinalTaskBeyondGracePeriod() {
                 .status(new TaskStatus(TaskState.COMPLETED))
                 .build();
 
-        taskStore.save(task);
+        taskStore.save(task, false);
 
         // Directly update the finalizedAt timestamp to 20 seconds in the past
         // (beyond the default 15-second grace period)
@@ -322,9 +322,9 @@ public void testListTasksFilterByContextId() {
                 .status(new TaskStatus(TaskState.COMPLETED))
                 .build();
 
-        taskStore.save(task1);
-        taskStore.save(task2);
-        taskStore.save(task3);
+        taskStore.save(task1, false);
+        taskStore.save(task2, false);
+        taskStore.save(task3, false);
 
         // List tasks for context-A
         ListTasksParams params = ListTasksParams.builder()
@@ -361,9 +361,9 @@ public void testListTasksFilterByStatus() {
                 .status(new TaskStatus(TaskState.COMPLETED))
                 .build();
 
-        taskStore.save(task1);
-        taskStore.save(task2);
-        taskStore.save(task3);
+        taskStore.save(task1, false);
+        taskStore.save(task2, false);
+        taskStore.save(task3, false);
 
         // List only WORKING tasks in this context
         ListTasksParams params = ListTasksParams.builder()
@@ -401,9 +401,9 @@ public void testListTasksCombinedFilters() {
                 .status(new TaskStatus(TaskState.WORKING))
                 .build();
 
-        taskStore.save(task1);
-        taskStore.save(task2);
-        taskStore.save(task3);
+        taskStore.save(task1, false);
+        taskStore.save(task2, false);
+        taskStore.save(task3, false);
 
         // List WORKING tasks in context-X
         ListTasksParams params = ListTasksParams.builder()
@@ -432,7 +432,7 @@ public void testListTasksPagination() {
                     .contextId("context-pagination")
                     .status(new TaskStatus(TaskState.SUBMITTED, null, sameTimestamp))
                     .build();
-            taskStore.save(task);
+            taskStore.save(task, false);
         }
 
         // First page: pageSize=2
@@ -488,7 +488,7 @@ public void testListTasksPaginationWithDifferentTimestamps() {
                 .contextId("context-diff-timestamps")
                 .status(new TaskStatus(TaskState.WORKING, null, now.minusMinutes(10)))
                 .build();
-        taskStore.save(task1);
+        taskStore.save(task1, false);
 
         // Task 2: 5 minutes ago, ID="task-diff-b"
         Task task2 = Task.builder()
@@ -496,7 +496,7 @@ public void testListTasksPaginationWithDifferentTimestamps() {
                 .contextId("context-diff-timestamps")
                 .status(new TaskStatus(TaskState.WORKING, null, now.minusMinutes(5)))
                 .build();
-        taskStore.save(task2);
+        taskStore.save(task2, false);
 
         // Task 3: 5 minutes ago, ID="task-diff-c" (same timestamp as task2, tests ID tie-breaker)
         Task task3 = Task.builder()
@@ -504,7 +504,7 @@ public void testListTasksPaginationWithDifferentTimestamps() {
                 .contextId("context-diff-timestamps")
                 .status(new TaskStatus(TaskState.WORKING, null, now.minusMinutes(5)))
                 .build();
-        taskStore.save(task3);
+        taskStore.save(task3, false);
 
         // Task 4: Now, ID="task-diff-d"
         Task task4 = Task.builder()
@@ -512,7 +512,7 @@ public void testListTasksPaginationWithDifferentTimestamps() {
                 .contextId("context-diff-timestamps")
                 .status(new TaskStatus(TaskState.WORKING, null, now))
                 .build();
-        taskStore.save(task4);
+        taskStore.save(task4, false);
 
         // Task 5: 1 minute ago, ID="task-diff-e"
         Task task5 = Task.builder()
@@ -520,7 +520,7 @@ public void testListTasksPaginationWithDifferentTimestamps() {
                 .contextId("context-diff-timestamps")
                 .status(new TaskStatus(TaskState.WORKING, null, now.minusMinutes(1)))
                 .build();
-        taskStore.save(task5);
+        taskStore.save(task5, false);
 
         // Expected order (timestamp DESC, id ASC):
         // 1. task-diff-d (now)
@@ -616,7 +616,7 @@ public void testListTasksHistoryLimiting() {
                 .history(longHistory)
                 .build();
 
-        taskStore.save(task);
+        taskStore.save(task, false);
 
         // List with historyLength=3 (should keep only last 3 messages) - filter by unique context
         ListTasksParams params = ListTasksParams.builder()
@@ -654,7 +654,7 @@ public void testListTasksArtifactInclusion() {
                 .artifacts(artifacts)
                 .build();
 
-        taskStore.save(task);
+        taskStore.save(task, false);
 
         // List without artifacts (default) - filter by unique context
         ListTasksParams paramsWithoutArtifacts = ListTasksParams.builder()
@@ -691,7 +691,7 @@ public void testListTasksDefaultPageSize() {
                     .contextId("context-default-pagesize")
                     .status(new TaskStatus(TaskState.SUBMITTED))
                     .build();
-            taskStore.save(task);
+            taskStore.save(task, false);
         }
 
         // List without specifying pageSize (should use default of 50)
@@ -715,7 +715,7 @@ public void testListTasksInvalidPageTokenFormat() {
                 .contextId("context-invalid-token")
                 .status(new TaskStatus(TaskState.WORKING))
                 .build();
-        taskStore.save(task);
+        taskStore.save(task, false);
         // Test 1: Legacy ID-only pageToken should throw InvalidParamsError
         ListTasksParams params1 = ListTasksParams.builder()
@@ -777,9 +777,9 @@ public void testListTasksOrderingById() {
                 .build();
 
         // Save in reverse order
-        taskStore.save(task3);
-        taskStore.save(task1);
-        taskStore.save(task2);
+        taskStore.save(task3, false);
+        taskStore.save(task1, false);
+        taskStore.save(task2, false);
 
         // List should return sorted by timestamp DESC (all same), then by ID ASC
         ListTasksParams params = ListTasksParams.builder()
diff --git a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java
index 18e18a2f1..cb5bdb25b 100644
--- a/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java
+++ b/reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java
@@ -13,7 +13,6 @@
 import java.util.concurrent.Executor;
 import java.util.concurrent.Flow;
 import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.Function;
 
 import jakarta.enterprise.inject.Instance;
 import jakarta.inject.Inject;
@@ -21,6 +20,7 @@
 import com.google.gson.JsonSyntaxException;
 
 import io.a2a.common.A2AHeaders;
+import io.a2a.server.util.sse.SseFormatter;
 import io.a2a.grpc.utils.JSONRPCUtils;
 import io.a2a.jsonrpc.common.json.IdJsonMappingException;
 import io.a2a.jsonrpc.common.json.InvalidParamsJsonMappingException;
@@ -65,7 +65,6 @@
 import io.a2a.transport.jsonrpc.handler.JSONRPCHandler;
 import io.quarkus.security.Authenticated;
 import io.quarkus.vertx.web.Body;
-import io.quarkus.vertx.web.ReactiveRoutes;
 import io.quarkus.vertx.web.Route;
 import io.smallrye.mutiny.Multi;
 import io.vertx.core.AsyncResult;
@@ -74,6 +73,8 @@
 import io.vertx.core.buffer.Buffer;
 import io.vertx.core.http.HttpServerResponse;
 import io.vertx.ext.web.RoutingContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 @Singleton
 public class A2AServerRoutes {
@@ -135,8 +136,12 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
         } else if (streaming) {
             final Multi<? extends A2AResponse> finalStreamingResponse = streamingResponse;
             executor.execute(() -> {
-                MultiSseSupport.subscribeObject(
-                        finalStreamingResponse.map(i -> (Object) i), rc);
+                // Convert Multi<A2AResponse> to Multi<String> with SSE formatting
+                AtomicLong eventIdCounter = new AtomicLong(0);
+                Multi<String> sseEvents = finalStreamingResponse
+                        .map(response -> SseFormatter.formatResponseAsSSE(response, eventIdCounter.getAndIncrement()));
+                // Write SSE-formatted strings to HTTP response
+                MultiSseSupport.writeSseStrings(sseEvents, rc, context);
             });
         } else {
@@ -295,34 +300,30 @@ private static com.google.protobuf.MessageOrBuilder convertToProto(A2AResponse
+    /**
+     * Simplified SSE support for Vert.x/Quarkus.
+     * <p>
+     * This class only handles HTTP-specific concerns (writing to response, backpressure, disconnect).
+     * SSE formatting and JSON serialization are handled by {@link SseFormatter}.
+     */
     private static class MultiSseSupport {
+        private static final Logger logger = LoggerFactory.getLogger(MultiSseSupport.class);
 
         private MultiSseSupport() {
             // Avoid direct instantiation.
         }
 
-        private static void initialize(HttpServerResponse response) {
-            if (response.bytesWritten() == 0) {
-                MultiMap headers = response.headers();
-                if (headers.get(CONTENT_TYPE) == null) {
-                    headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
-                }
-                response.setChunked(true);
-            }
-        }
-
-        private static void onWriteDone(Flow.Subscription subscription, AsyncResult<Void> ar, RoutingContext rc) {
-            if (ar.failed()) {
-                rc.fail(ar.cause());
-            } else {
-                subscription.request(1);
-            }
-        }
-
-        public static void write(Multi<Buffer> multi, RoutingContext rc) {
+        /**
+         * Write SSE-formatted strings to HTTP response.
+         *
+         * @param sseStrings Multi stream of SSE-formatted strings (from SseFormatter)
+         * @param rc Vert.x routing context
+         * @param context A2A server call context (for EventConsumer cancellation)
+         */
+        public static void writeSseStrings(Multi<String> sseStrings, RoutingContext rc, ServerCallContext context) {
             HttpServerResponse response = rc.response();
-            multi.subscribe().withSubscriber(new Flow.Subscriber<Buffer>() {
+
+            sseStrings.subscribe().withSubscriber(new Flow.Subscriber<String>() {
                 Flow.Subscription upstream;
 
                 @Override
@@ -330,6 +331,13 @@ public void onSubscribe(Flow.Subscription subscription) {
                     this.upstream = subscription;
                     this.upstream.request(1);
 
+                    // Detect client disconnect and call EventConsumer.cancel() directly
+                    response.closeHandler(v -> {
+                        logger.info("SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop");
+                        context.invokeEventConsumerCancelCallback();
+                        subscription.cancel();
+                    });
+
                     // Notify tests that we are subscribed
                     Runnable runnable = streamingMultiSseSupportSubscribedRunnable;
                     if (runnable != null) {
@@ -338,54 +346,50 @@
                 }
 
                 @Override
-                public void onNext(Buffer item) {
-                    initialize(response);
-                    response.write(item, new Handler<AsyncResult<Void>>() {
+                public void onNext(String sseEvent) {
+                    // Set SSE headers on first event
+                    if (response.bytesWritten() == 0) {
+                        MultiMap headers = response.headers();
+                        if (headers.get(CONTENT_TYPE) == null) {
+                            headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
+                        }
+                        response.setChunked(true);
+                    }
+
+                    // Write SSE-formatted string to response
+                    response.write(Buffer.buffer(sseEvent), new Handler<AsyncResult<Void>>() {
                         @Override
                         public void handle(AsyncResult<Void> ar) {
-                            onWriteDone(upstream, ar, rc);
+                            if (ar.failed()) {
+                                // Client disconnected or write failed - cancel upstream to stop EventConsumer
+                                upstream.cancel();
+                                rc.fail(ar.cause());
+                            } else {
+                                upstream.request(1);
+                            }
                         }
                     });
                 }
 
                 @Override
                 public void onError(Throwable throwable) {
+                    // Cancel upstream to stop EventConsumer when error occurs
+                    upstream.cancel();
                     rc.fail(throwable);
                 }
 
                 @Override
                 public void onComplete() {
-                    endOfStream(response);
-                }
-            });
-        }
-
-        public static void subscribeObject(Multi<Object> multi, RoutingContext rc) {
-            AtomicLong count = new AtomicLong();
-            write(multi.map(new Function<Object, Buffer>() {
-                @Override
-                public Buffer apply(Object o) {
-                    if (o instanceof ReactiveRoutes.ServerSentEvent) {
-                        ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o;
-                        long id = ev.id() != -1 ? ev.id() : count.getAndIncrement();
-                        String e = ev.event() == null ? "" : "event: " + ev.event() + "\n";
-                        String data = serializeResponse((A2AResponse) ev.data());
-                        return Buffer.buffer(e + "data: " + data + "\nid: " + id + "\n\n");
+                    if (response.bytesWritten() == 0) {
+                        // No events written - still set SSE content type
+                        MultiMap headers = response.headers();
+                        if (headers.get(CONTENT_TYPE) == null) {
+                            headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
+                        }
                     }
-                    String data = serializeResponse((A2AResponse) o);
-                    return Buffer.buffer("data: " + data + "\nid: " + count.getAndIncrement() + "\n\n");
-                }
-            }), rc);
-        }
-
-        private static void endOfStream(HttpServerResponse response) {
-            if (response.bytesWritten() == 0) { // No item
-                MultiMap headers = response.headers();
-                if (headers.get(CONTENT_TYPE) == null) {
-                    headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
+                    response.end();
                 }
-            }
-            response.end();
+            });
         }
     }
 }
diff --git a/reference/jsonrpc/src/test/resources/application.properties b/reference/jsonrpc/src/test/resources/application.properties
index 7b9cea9cc..e612925d4 100644
--- a/reference/jsonrpc/src/test/resources/application.properties
+++ b/reference/jsonrpc/src/test/resources/application.properties
@@ -1 +1,6 @@
 quarkus.arc.selected-alternatives=io.a2a.server.apps.common.TestHttpClient
+
+# Debug logging for event processing and request handling
+quarkus.log.category."io.a2a.server.events".level=DEBUG
+quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG
+quarkus.log.category."io.a2a.server.tasks".level=DEBUG
diff --git a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java
index 46d0d38e6..7a50f0afb 100644
--- a/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java
+++ b/reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java
@@ -15,7 +15,8 @@
 import java.util.concurrent.Executor;
 import java.util.concurrent.Flow;
 import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.Function;
+
+import io.a2a.server.util.sse.SseFormatter;
 
 import jakarta.annotation.security.PermitAll;
 import jakarta.enterprise.inject.Instance;
@@ -38,7 +39,6 @@
 import io.a2a.util.Utils;
 import io.quarkus.security.Authenticated;
 import io.quarkus.vertx.web.Body;
-import io.quarkus.vertx.web.ReactiveRoutes;
 import io.quarkus.vertx.web.Route;
 import io.smallrye.mutiny.Multi;
 import io.vertx.core.AsyncResult;
@@ -110,10 +110,14 @@ public void sendMessageStreaming(@Body String body, RoutingContext rc) {
         if (error != null) {
             sendResponse(rc, error);
         } else if (streamingResponse != null) {
-            Multi<String> events = Multi.createFrom().publisher(streamingResponse.getPublisher());
+            final HTTPRestStreamingResponse finalStreamingResponse = streamingResponse;
             executor.execute(() -> {
-                MultiSseSupport.subscribeObject(
-                        events.map(i -> (Object) i), rc);
+                // Convert Flow.Publisher<String> (JSON) to Multi<String> (SSE-formatted)
+                AtomicLong eventIdCounter = new AtomicLong(0);
+                Multi<String> sseEvents = Multi.createFrom().publisher(finalStreamingResponse.getPublisher())
+                        .map(json -> SseFormatter.formatJsonAsSSE(json, eventIdCounter.getAndIncrement()));
+                // Write SSE-formatted strings to HTTP response
+                MultiSseSupport.writeSseStrings(sseEvents, rc, context);
             });
         }
     }
@@ -243,10 +247,14 @@ public void subscribeToTask(RoutingContext rc) {
         if (error != null) {
             sendResponse(rc, error);
         } else if (streamingResponse != null) {
-            Multi<String> events = Multi.createFrom().publisher(streamingResponse.getPublisher());
+            final HTTPRestStreamingResponse finalStreamingResponse = streamingResponse;
             executor.execute(() -> {
-                MultiSseSupport.subscribeObject(
-                        events.map(i -> (Object) i), rc);
+                // Convert Flow.Publisher<String> (JSON) to Multi<String> (SSE-formatted)
+                AtomicLong eventIdCounter = new AtomicLong(0);
+                Multi<String> sseEvents = Multi.createFrom().publisher(finalStreamingResponse.getPublisher())
+                        .map(json -> SseFormatter.formatJsonAsSSE(json, eventIdCounter.getAndIncrement()));
+                // Write SSE-formatted strings to HTTP response
+                MultiSseSupport.writeSseStrings(sseEvents, rc, context);
             });
         }
     }
@@ -450,34 +458,30 @@ public String getUsername() {
         }
     }
 
-    // Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API
+    /**
+     * Simplified SSE support for Vert.x/Quarkus.
+     * <p>
+     * This class only handles HTTP-specific concerns (writing to response, backpressure, disconnect).
+     * SSE formatting and JSON serialization are handled by {@link SseFormatter}.
+     */
     private static class MultiSseSupport {
+        private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(MultiSseSupport.class);
 
         private MultiSseSupport() {
             // Avoid direct instantiation.
         }
 
-        private static void initialize(HttpServerResponse response) {
-            if (response.bytesWritten() == 0) {
-                MultiMap headers = response.headers();
-                if (headers.get(CONTENT_TYPE) == null) {
-                    headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
-                }
-                response.setChunked(true);
-            }
-        }
-
-        private static void onWriteDone(Flow.@Nullable Subscription subscription, AsyncResult<Void> ar, RoutingContext rc) {
-            if (ar.failed()) {
-                rc.fail(ar.cause());
-            } else if (subscription != null) {
-                subscription.request(1);
-            }
-        }
-
-        private static void write(Multi<Buffer> multi, RoutingContext rc) {
+        /**
+         * Write SSE-formatted strings to HTTP response.
+         *
+         * @param sseStrings Multi stream of SSE-formatted strings (from SseFormatter)
+         * @param rc Vert.x routing context
+         * @param context A2A server call context (for EventConsumer cancellation)
+         */
+        public static void writeSseStrings(Multi<String> sseStrings, RoutingContext rc, ServerCallContext context) {
             HttpServerResponse response = rc.response();
-            multi.subscribe().withSubscriber(new Flow.Subscriber<Buffer>() {
+
+            sseStrings.subscribe().withSubscriber(new Flow.Subscriber<String>() {
                 Flow.@Nullable Subscription upstream;
 
                 @Override
@@ -485,6 +489,13 @@ public void onSubscribe(Flow.Subscription subscription) {
                     this.upstream = subscription;
                     this.upstream.request(1);
 
+                    // Detect client disconnect and call EventConsumer.cancel() directly
+                    response.closeHandler(v -> {
+                        logger.debug("REST SSE connection closed by client, calling EventConsumer.cancel() to stop polling loop");
+                        context.invokeEventConsumerCancelCallback();
+                        subscription.cancel();
+                    });
+
                     // Notify tests that we are subscribed
                     Runnable runnable = streamingMultiSseSupportSubscribedRunnable;
                     if (runnable != null) {
@@ -493,53 +504,64 @@
                 }
 
                 @Override
-                public void onNext(Buffer item) {
-                    initialize(response);
-                    response.write(item, new Handler<AsyncResult<Void>>() {
+                public void onNext(String sseEvent) {
+                    // Set SSE headers on first event
+                    if (response.bytesWritten() == 0) {
+                        MultiMap headers = response.headers();
+                        if (headers.get(CONTENT_TYPE) == null) {
+                            headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
+                        }
+                        // Additional SSE headers to prevent buffering
+                        headers.set("Cache-Control", "no-cache");
+                        headers.set("X-Accel-Buffering", "no"); // Disable nginx buffering
+                        response.setChunked(true);
+
+                        // CRITICAL: Disable write queue max size to prevent buffering
+                        // Vert.x buffers writes by default - we need immediate flushing for SSE
+                        response.setWriteQueueMaxSize(1); // Force immediate flush
+
+                        // Send initial SSE comment to kickstart the stream
+                        // This forces Vert.x to send headers and start the stream immediately
+                        response.write(": SSE stream started\n\n");
+                    }
+
+                    // Write SSE-formatted string to response
+                    response.write(Buffer.buffer(sseEvent), new Handler<AsyncResult<Void>>() {
                         @Override
                         public void handle(AsyncResult<Void> ar) {
-                            onWriteDone(upstream, ar, rc);
+                            if (ar.failed()) {
+                                // Client disconnected or write failed - cancel upstream to stop EventConsumer
+                                // NullAway: upstream is guaranteed non-null after onSubscribe
+                                java.util.Objects.requireNonNull(upstream).cancel();
+                                rc.fail(ar.cause());
+                            } else {
+                                // NullAway: upstream is guaranteed non-null after onSubscribe
+                                java.util.Objects.requireNonNull(upstream).request(1);
+                            }
                         }
                     });
                 }
 
                 @Override
                 public void onError(Throwable throwable) {
+                    // Cancel upstream to stop EventConsumer when error occurs
+                    // NullAway: upstream is guaranteed non-null after onSubscribe
+                    java.util.Objects.requireNonNull(upstream).cancel();
                     rc.fail(throwable);
                 }
 
                 @Override
                 public void onComplete() {
-                    endOfStream(response);
-                }
-            });
-        }
-
-        private static void subscribeObject(Multi<Object> multi, RoutingContext rc) {
-            AtomicLong count = new AtomicLong();
-            write(multi.map(new Function<Object, Buffer>() {
-                @Override
-                public Buffer apply(Object o) {
-                    if (o instanceof ReactiveRoutes.ServerSentEvent) {
-                        ReactiveRoutes.ServerSentEvent ev = (ReactiveRoutes.ServerSentEvent) o;
-                        long id = ev.id() != -1 ? ev.id() : count.getAndIncrement();
-                        String e = ev.event() == null ? "" : "event: " + ev.event() + "\n";
-                        return Buffer.buffer(e + "data: " + ev.data() + "\nid: " + id + "\n\n");
-                    } else {
-                        return Buffer.buffer("data: " + o + "\nid: " + count.getAndIncrement() + "\n\n");
+                    if (response.bytesWritten() == 0) {
+                        // No events written - still set SSE content type
+                        MultiMap headers = response.headers();
+                        if (headers.get(CONTENT_TYPE) == null) {
+                            headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
+                        }
                     }
+                    response.end();
                 }
-            }), rc);
-        }
-
-        private static void endOfStream(HttpServerResponse response) {
-            if (response.bytesWritten() == 0) { // No item
-                MultiMap headers = response.headers();
-                if (headers.get(CONTENT_TYPE) == null) {
-                    headers.set(CONTENT_TYPE, SERVER_SENT_EVENTS);
-                }
-            }
-            response.end();
+            });
         }
     }
diff --git a/server-common/src/main/java/io/a2a/server/ServerCallContext.java b/server-common/src/main/java/io/a2a/server/ServerCallContext.java
index ba5c20b95..c12c60c21 100644
--- a/server-common/src/main/java/io/a2a/server/ServerCallContext.java
+++ b/server-common/src/main/java/io/a2a/server/ServerCallContext.java
@@ -16,6 +16,7 @@ public class ServerCallContext {
     private final Set<String> requestedExtensions;
     private final Set<String> activatedExtensions;
     private final @Nullable String requestedProtocolVersion;
+    private volatile @Nullable Runnable eventConsumerCancelCallback;
 
     public ServerCallContext(User user, Map<String, Object> state, Set<String> requestedExtensions) {
         this(user, state, requestedExtensions, null);
@@ -64,4 +65,64 @@ public boolean isExtensionRequested(String extensionUri) {
     public @Nullable String getRequestedProtocolVersion() {
         return requestedProtocolVersion;
     }
+
+    /**
+     * Sets the callback to be invoked when the client disconnects or the call is cancelled.
+     * <p>
+     * This callback is typically used to stop the EventConsumer polling loop when a client
+     * disconnects from a streaming endpoint. The callback is invoked by transport layers
+     * (JSON-RPC over HTTP/SSE, REST over HTTP/SSE, gRPC streaming) when they detect that
+     * the client has closed the connection.
+     * <p>
+     * <b>Thread Safety:</b> The callback may be invoked from any thread, depending
+     * on the transport implementation. Implementations should be thread-safe.
+     * <p>
+     * <b>Example Usage:</b>
+     * <pre>{@code
+     * EventConsumer consumer = new EventConsumer(queue);
+     * context.setEventConsumerCancelCallback(consumer::cancel);
+     * }</pre>
+     *
+     * @param callback the callback to invoke on client disconnect, or null to clear any existing callback
+     * @see #invokeEventConsumerCancelCallback()
+     */
+    public void setEventConsumerCancelCallback(@Nullable Runnable callback) {
+        this.eventConsumerCancelCallback = callback;
+    }
+
+    /**
+     * Invokes the EventConsumer cancel callback if one has been set.
+     * <p>
+     * This method is called by transport layers when a client disconnects or cancels a
+     * streaming request. It triggers the callback registered via
+     * {@link #setEventConsumerCancelCallback(Runnable)}, which typically stops the
+     * EventConsumer polling loop.
+     * <p>
+     * <b>Transport-Specific Behavior:</b>
+     * <ul>
+     *   <li><b>JSON-RPC/REST over HTTP/SSE:</b> Called from Vert.x
+     *   {@code HttpServerResponse.closeHandler()} when the SSE connection is closed</li>
+     *   <li><b>gRPC streaming:</b> Called from gRPC
+     *   {@code Context.CancellationListener.cancelled()} when the call is cancelled</li>
+     * </ul>
+     * <p>
+     * <b>Thread Safety:</b> This method is thread-safe. The callback is stored
+     * in a volatile field and null-checked before invocation to prevent race conditions.
+     * <p>
+     * If no callback has been set, this method does nothing (no-op).
+     *
+     * @see #setEventConsumerCancelCallback(Runnable)
+     * @see io.a2a.server.events.EventConsumer#cancel()
+     */
+    public void invokeEventConsumerCancelCallback() {
+        Runnable callback = this.eventConsumerCancelCallback;
+        if (callback != null) {
+            callback.run();
+        }
+    }
 }
diff --git a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java
index d4fe5b395..7c7b28452 100644
--- a/server-common/src/main/java/io/a2a/server/events/EventConsumer.java
+++ b/server-common/src/main/java/io/a2a/server/events/EventConsumer.java
@@ -19,10 +19,17 @@ public class EventConsumer {
     private static final Logger LOGGER = LoggerFactory.getLogger(EventConsumer.class);
     private final EventQueue queue;
     private volatile @Nullable Throwable error;
+    private volatile boolean cancelled = false;
+    private volatile boolean agentCompleted = false;
+    private volatile int pollTimeoutsAfterAgentCompleted = 0;
 
     private static final String ERROR_MSG = "Agent did not return any response";
     private static final int NO_WAIT = -1;
     private static final int QUEUE_WAIT_MILLISECONDS = 500;
+    // In replicated scenarios, events can arrive hundreds of milliseconds after local agent completes
+    // Grace period allows Kafka replication to deliver late-arriving events
+    // 3 timeouts * 500ms = 1500ms grace period for replication delays
+    private static final int MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED = 3;
 
     public EventConsumer(EventQueue queue) {
         this.queue = queue;
@@ -45,6 +52,14 @@ public Flow.Publisher<EventQueueItem> consumeAll() {
             boolean completed = false;
             try {
                 while (true) {
+                    // Check if cancelled by client disconnect
+                    if (cancelled) {
+                        LOGGER.debug("EventConsumer detected cancellation, exiting polling loop for queue {}", System.identityHashCode(queue));
+                        completed = true;
+                        tube.complete();
+                        return;
+                    }
+
                     if (error != null) {
                         completed = true;
                         tube.fail(error);
@@ -60,13 +75,49 @@
                     EventQueueItem item;
                     Event event;
                     try {
+                        LOGGER.debug("EventConsumer polling queue {} (error={}, agentCompleted={})",
+                                System.identityHashCode(queue), error, agentCompleted);
                         item = queue.dequeueEventItem(QUEUE_WAIT_MILLISECONDS);
                         if (item == null) {
+                            int queueSize = queue.size();
+                            LOGGER.debug("EventConsumer poll timeout (null item), agentCompleted={}, queue.size()={}, timeoutCount={}",
+                                    agentCompleted, queueSize, pollTimeoutsAfterAgentCompleted);
+                            // If agent completed, a poll timeout means no more events are coming
+                            // MainEventBusProcessor has 500ms to distribute events from MainEventBus
+                            // If we timeout with agentCompleted=true, all events have been distributed
+                            //
+                            // IMPORTANT: In replicated scenarios, remote events may arrive AFTER local agent completes!
+                            // Use grace period to allow for Kafka replication delays (can be 400-500ms)
+                            if (agentCompleted && queueSize == 0) {
+                                pollTimeoutsAfterAgentCompleted++;
+                                if (pollTimeoutsAfterAgentCompleted >= MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED) {
+                                    LOGGER.debug("Agent completed with {} consecutive poll timeouts and empty queue, closing for graceful completion (queue={})",
+                                            pollTimeoutsAfterAgentCompleted, System.identityHashCode(queue));
+                                    queue.close();
+                                    completed = true;
+                                    tube.complete();
+                                    return;
+                                } else {
+                                    LOGGER.debug("Agent completed but grace period active ({}/{} timeouts), continuing to poll (queue={})",
+                                            pollTimeoutsAfterAgentCompleted, MAX_POLL_TIMEOUTS_AFTER_AGENT_COMPLETED, System.identityHashCode(queue));
+                                }
+                            } else if (agentCompleted && queueSize > 0) {
+                                LOGGER.debug("Agent completed but queue has {} pending events, resetting timeout counter and continuing to poll (queue={})",
+                                        queueSize, System.identityHashCode(queue));
+                                pollTimeoutsAfterAgentCompleted = 0; // Reset counter when events arrive
+                            }
                             continue;
                         }
+                        // Event received - reset timeout counter
+                        pollTimeoutsAfterAgentCompleted = 0;
                        event = item.getEvent();
+                        LOGGER.debug("EventConsumer received event: {} (queue={})",
+                                event.getClass().getSimpleName(), System.identityHashCode(queue));
+                        // Defensive logging for error handling
                        if (event instanceof Throwable thr) {
+                            LOGGER.debug("EventConsumer detected Throwable event: {} - triggering tube.fail()",
+                                    thr.getClass().getSimpleName());
                            tube.fail(thr);
                            return;
                        }
@@ -90,14 +141,30 @@
                        // Only send event if it's not a QueueClosedEvent
                        // QueueClosedEvent is an internal coordination event used for replication
                        // and should not be exposed to API consumers
+                        boolean isFinalSent = false;
                        if (!(event instanceof QueueClosedEvent)) {
                            tube.send(item);
+                            isFinalSent = isFinalEvent;
                        }
                        if (isFinalEvent) {
                            LOGGER.debug("Final or interrupted event detected, closing queue and breaking loop for queue {}", System.identityHashCode(queue));
                            queue.close();
                            LOGGER.debug("Queue closed, breaking loop for queue {}", System.identityHashCode(queue));
+
+                            // CRITICAL: Allow tube buffer to flush before calling tube.complete()
+                            // tube.send() buffers events asynchronously. If we call tube.complete() immediately,
+                            // the stream-end signal can reach the client BEFORE the buffered final event,
+                            // causing the client to close the connection and never receive the final event.
+                            // This is especially important in replicated scenarios where events arrive via Kafka
+                            // and timing is less deterministic. A small delay ensures the buffer flushes.
+ if (isFinalSent) { + try { + Thread.sleep(50); // 50ms to allow SSE buffer flush + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } break; } } catch (EventQueueClosedException e) { @@ -138,12 +205,25 @@ private boolean isStreamTerminatingTask(Task task) { public EnhancedRunnable.DoneCallback createAgentRunnableDoneCallback() { return agentRunnable -> { + LOGGER.debug("EventConsumer: Agent done callback invoked (hasError={}, queue={})", + agentRunnable.getError() != null, System.identityHashCode(queue)); if (agentRunnable.getError() != null) { error = agentRunnable.getError(); + LOGGER.debug("EventConsumer: Set error field from agent callback"); + } else { + agentCompleted = true; + LOGGER.debug("EventConsumer: Agent completed successfully, set agentCompleted=true, will close queue after draining"); } }; } + public void cancel() { + // Set cancellation flag to stop polling loop + // Called when client disconnects without completing stream + LOGGER.debug("EventConsumer cancelled (client disconnect), stopping polling for queue {}", System.identityHashCode(queue)); + cancelled = true; + } + public void close() { // Close the queue to stop the polling loop in consumeAll() // This will cause EventQueueClosedException and exit the while(true) loop diff --git a/server-common/src/main/java/io/a2a/server/events/EventQueue.java b/server-common/src/main/java/io/a2a/server/events/EventQueue.java index a08f63084..99f8bc2dc 100644 --- a/server-common/src/main/java/io/a2a/server/events/EventQueue.java +++ b/server-common/src/main/java/io/a2a/server/events/EventQueue.java @@ -1,6 +1,7 @@ package io.a2a.server.events; import java.util.List; +import java.util.Objects; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; @@ -11,6 +12,8 @@ import io.a2a.server.tasks.TaskStateProvider; import io.a2a.spec.Event; +import io.a2a.spec.Task; +import 
io.a2a.spec.TaskStatusUpdateEvent; import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,7 +26,7 @@ * and hierarchical queue structures via MainQueue and ChildQueue implementations. *

*

- * Use {@link #builder()} to create configured instances or extend MainQueue/ChildQueue directly. + * Use {@link #builder(MainEventBus)} to create configured instances or extend MainQueue/ChildQueue directly. *

*/ public abstract class EventQueue implements AutoCloseable { @@ -36,14 +39,6 @@ public abstract class EventQueue implements AutoCloseable { public static final int DEFAULT_QUEUE_SIZE = 1000; private final int queueSize; - /** - * Internal blocking queue for storing event queue items. - */ - protected final BlockingQueue queue = new LinkedBlockingDeque<>(); - /** - * Semaphore for backpressure control, limiting the number of pending events. - */ - protected final Semaphore semaphore; private volatile boolean closed = false; /** @@ -64,7 +59,6 @@ protected EventQueue(int queueSize) { throw new IllegalArgumentException("Queue size must be greater than 0"); } this.queueSize = queueSize; - this.semaphore = new Semaphore(queueSize, true); LOGGER.trace("Creating {} with queue size: {}", this, queueSize); } @@ -78,8 +72,8 @@ protected EventQueue(EventQueue parent) { LOGGER.trace("Creating {}, parent: {}", this, parent); } - static EventQueueBuilder builder() { - return new EventQueueBuilder(); + static EventQueueBuilder builder(MainEventBus mainEventBus) { + return new EventQueueBuilder().mainEventBus(mainEventBus); } /** @@ -95,6 +89,7 @@ public static class EventQueueBuilder { private @Nullable String taskId; private List onCloseCallbacks = new java.util.ArrayList<>(); private @Nullable TaskStateProvider taskStateProvider; + private @Nullable MainEventBus mainEventBus; /** * Sets the maximum queue size. @@ -153,17 +148,31 @@ public EventQueueBuilder taskStateProvider(TaskStateProvider taskStateProvider) return this; } + /** + * Sets the main event bus + * + * @param mainEventBus the main event bus + * @return this builder + */ + public EventQueueBuilder mainEventBus(MainEventBus mainEventBus) { + this.mainEventBus = mainEventBus; + return this; + } + /** * Builds and returns the configured EventQueue. 
* * @return a new MainQueue instance */ public EventQueue build() { - if (hook != null || !onCloseCallbacks.isEmpty() || taskStateProvider != null) { - return new MainQueue(queueSize, hook, taskId, onCloseCallbacks, taskStateProvider); - } else { - return new MainQueue(queueSize); + // MainEventBus is REQUIRED - enforce single architectural path + if (mainEventBus == null) { + throw new IllegalStateException("MainEventBus is required for EventQueue creation"); } + if (taskId == null) { + throw new IllegalStateException("taskId is required for EventQueue creation"); + } + return new MainQueue(queueSize, hook, taskId, onCloseCallbacks, taskStateProvider, mainEventBus); } } @@ -209,21 +218,39 @@ public void enqueueEvent(Event event) { * @param item the event queue item to enqueue * @throws RuntimeException if interrupted while waiting to acquire the semaphore */ - public void enqueueItem(EventQueueItem item) { - Event event = item.getEvent(); - if (closed) { - LOGGER.warn("Queue is closed. Event will not be enqueued. {} {}", this, event); - return; - } - // Call toString() since for errors we don't really want the full stacktrace - try { - semaphore.acquire(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e); - } - queue.add(item); - LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); + public abstract void enqueueItem(EventQueueItem item); + + /** + * Enqueues an event directly to this specific queue only, bypassing the MainEventBus. + *

+ * This method is used for enqueuing already-persisted events (e.g., current task state + * on resubscribe) that should only be sent to this specific subscriber, not distributed + * to all children or sent through MainEventBusProcessor. + *
<p>
+ *
<p>
+ * Default implementation throws UnsupportedOperationException. Only ChildQueue supports this. + *
<p>
+ * + * @param item the event queue item to enqueue directly + * @throws UnsupportedOperationException if called on MainQueue or other queue types + */ + public void enqueueLocalOnly(EventQueueItem item) { + throw new UnsupportedOperationException( + "enqueueLocalOnly is only supported on ChildQueue for resubscribe scenarios"); + } + + /** + * Enqueues an event directly to this specific queue only, bypassing the MainEventBus. + *
<p>
+ * Convenience method that wraps the event in a LocalEventQueueItem before calling + * {@link #enqueueLocalOnly(EventQueueItem)}. + *
<p>
+ * + * @param event the event to enqueue directly + * @throws UnsupportedOperationException if called on MainQueue or other queue types + */ + public void enqueueEventLocalOnly(Event event) { + enqueueLocalOnly(new LocalEventQueueItem(event)); } /** @@ -244,48 +271,17 @@ public void enqueueItem(EventQueueItem item) { * This method returns the full EventQueueItem wrapper, allowing callers to check * metadata like whether the event is replicated via {@link EventQueueItem#isReplicated()}. *
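The local-only bypass above can be modeled with a small JDK-only sketch (illustrative only: the lists stand in for ChildQueues, and `broadcast`/`enqueueLocalOnly` here are stand-ins, not the project implementations). A broadcast reaches every subscriber, while a local-only enqueue replays already-persisted state to just one late subscriber:

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: broadcast vs. local-only delivery for a late resubscriber.
public class LocalOnlySketch {
    static final List<List<String>> subscribers = new ArrayList<>();

    static void broadcast(String event) {           // normal path: every subscriber sees it
        for (List<String> s : subscribers) {
            s.add(event);
        }
    }

    static void enqueueLocalOnly(List<String> subscriber, String event) {
        subscriber.add(event);                      // bypass: only this subscriber sees it
    }

    public static void main(String[] args) {
        List<String> early = new ArrayList<>();
        subscribers.add(early);
        broadcast("status:working");

        List<String> late = new ArrayList<>();      // subscribes after the event was persisted
        subscribers.add(late);
        enqueueLocalOnly(late, "snapshot:working"); // replay persisted state to it alone

        System.out.println(early);
        System.out.println(late);
    }
}
```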
<p>
+ *
<p>
+ * Note: MainQueue does not support dequeue operations - only ChildQueues can be consumed. + *
<p>
* * @param waitMilliSeconds the maximum time to wait in milliseconds * @return the EventQueueItem, or null if timeout occurs * @throws EventQueueClosedException if the queue is closed and empty + * @throws UnsupportedOperationException if called on MainQueue */ - public @Nullable EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { - if (closed && queue.isEmpty()) { - LOGGER.debug("Queue is closed, and empty. Sending termination message. {}", this); - throw new EventQueueClosedException(); - } - try { - if (waitMilliSeconds <= 0) { - EventQueueItem item = queue.poll(); - if (item != null) { - Event event = item.getEvent(); - // Call toString() since for errors we don't really want the full stacktrace - LOGGER.debug("Dequeued event item (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event); - semaphore.release(); - } - return item; - } - try { - LOGGER.trace("Polling queue {} (wait={}ms)", System.identityHashCode(this), waitMilliSeconds); - EventQueueItem item = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); - if (item != null) { - Event event = item.getEvent(); - // Call toString() since for errors we don't really want the full stacktrace - LOGGER.debug("Dequeued event item (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event); - semaphore.release(); - } else { - LOGGER.trace("Dequeue timeout (null) from queue {}", System.identityHashCode(this)); - } - return item; - } catch (InterruptedException e) { - LOGGER.debug("Interrupted dequeue (waiting) {}", this); - Thread.currentThread().interrupt(); - return null; - } - } finally { - signalQueuePollerStarted(); - } - } + @Nullable + public abstract EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException; /** * Placeholder method for task completion notification. @@ -295,6 +291,18 @@ public void taskDone() { // TODO Not sure if needed yet. BlockingQueue.poll()/.take() remove the events. 
} + /** + * Returns the current size of the queue. + *
<p>
+ * For MainQueue: returns the number of events in-flight (in MainEventBus queue + currently being processed). + * This reflects actual capacity usage tracked by the semaphore. + * For ChildQueue: returns the size of the local consumption queue. + *
<p>
+ * + * @return the number of events currently in the queue + */ + public abstract int size(); + /** * Closes this event queue gracefully, allowing pending events to be consumed. */ @@ -348,72 +356,71 @@ protected void doClose(boolean immediate) { LOGGER.debug("Closing {} (immediate={})", this, immediate); closed = true; } - - if (immediate) { - // Immediate close: clear pending events - queue.clear(); - LOGGER.debug("Cleared queue for immediate close: {}", this); - } - // For graceful close, let the queue drain naturally through normal consumption + // Subclasses handle immediate close logic (e.g., ChildQueue clears its local queue) } static class MainQueue extends EventQueue { private final List children = new CopyOnWriteArrayList<>(); + protected final Semaphore semaphore; private final CountDownLatch pollingStartedLatch = new CountDownLatch(1); private final AtomicBoolean pollingStarted = new AtomicBoolean(false); private final @Nullable EventEnqueueHook enqueueHook; - private final @Nullable String taskId; + private final String taskId; private final List onCloseCallbacks; private final @Nullable TaskStateProvider taskStateProvider; - - MainQueue() { - super(); - this.enqueueHook = null; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize) { - super(queueSize); - this.enqueueHook = null; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(EventEnqueueHook hook) { - super(); - this.enqueueHook = hook; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize, EventEnqueueHook hook) { - super(queueSize); - this.enqueueHook = hook; - this.taskId = null; - this.onCloseCallbacks = List.of(); - this.taskStateProvider = null; - } - - MainQueue(int queueSize, @Nullable EventEnqueueHook hook, @Nullable String taskId, List onCloseCallbacks, @Nullable TaskStateProvider 
taskStateProvider) { + private final MainEventBus mainEventBus; + + MainQueue(int queueSize, + @Nullable EventEnqueueHook hook, + String taskId, + List onCloseCallbacks, + @Nullable TaskStateProvider taskStateProvider, + @Nullable MainEventBus mainEventBus) { super(queueSize); + this.semaphore = new Semaphore(queueSize, true); this.enqueueHook = hook; this.taskId = taskId; this.onCloseCallbacks = List.copyOf(onCloseCallbacks); // Defensive copy this.taskStateProvider = taskStateProvider; - LOGGER.debug("Created MainQueue for task {} with {} onClose callbacks and TaskStateProvider: {}", + this.mainEventBus = Objects.requireNonNull(mainEventBus, "MainEventBus is required"); + LOGGER.debug("Created MainQueue for task {} with {} onClose callbacks, TaskStateProvider: {}, MainEventBus configured", taskId, onCloseCallbacks.size(), taskStateProvider != null); } + public EventQueue tap() { ChildQueue child = new ChildQueue(this); children.add(child); return child; } + /** + * Returns the current number of child queues. + * Useful for debugging and logging event distribution. + */ + public int getChildCount() { + return children.size(); + } + + /** + * Returns the enqueue hook for replication (package-protected for MainEventBusProcessor). 
+ */ + @Nullable EventEnqueueHook getEnqueueHook() { + return enqueueHook; + } + + @Override + public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { + throw new UnsupportedOperationException("MainQueue cannot be consumed directly - use tap() to create a ChildQueue for consumption"); + } + + @Override + public int size() { + // Return total in-flight events (in MainEventBus + being processed) + // This aligns with semaphore's capacity tracking + return getQueueSize() - semaphore.availablePermits(); + } + @Override public void enqueueItem(EventQueueItem item) { // MainQueue must accept events even when closed to support: @@ -424,6 +431,15 @@ public void enqueueItem(EventQueueItem item) { // We bypass the parent's closed check and enqueue directly Event event = item.getEvent(); + // Check if this is a final event BEFORE submitting to MainEventBus + // If it is, notify all children to expect it (so they wait for MainEventBusProcessor) + if (isFinalEvent(event)) { + LOGGER.debug("Final event detected, notifying {} children to expect it", children.size()); + for (ChildQueue child : children) { + child.expectFinalEvent(); + } + } + // Acquire semaphore for backpressure try { semaphore.acquire(); @@ -432,17 +448,27 @@ public void enqueueItem(EventQueueItem item) { throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e); } - // Add to this MainQueue's internal queue - queue.add(item); LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? 
event.toString() : event, this); - // Distribute to all ChildQueues (they will receive the event even if MainQueue is closed) - children.forEach(eq -> eq.internalEnqueueItem(item)); + // Submit to MainEventBus for centralized persistence + distribution + // MainEventBus is guaranteed non-null by constructor requirement + // Note: Replication now happens in MainEventBusProcessor AFTER persistence - // Trigger replication hook if configured - if (enqueueHook != null) { - enqueueHook.onEnqueue(item); + // Submit event to MainEventBus with our taskId + mainEventBus.submit(taskId, this, item); + } + + /** + * Checks if an event represents a final task state. + */ + private boolean isFinalEvent(Event event) { + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); } + return false; } @Override @@ -465,20 +491,15 @@ public void signalQueuePollerStarted() { void childClosing(ChildQueue child, boolean immediate) { children.remove(child); // Remove the closing child - // Close immediately if requested - if (immediate) { - LOGGER.debug("MainQueue closing immediately (immediate=true)"); - this.doClose(immediate); - return; - } - // If there are still children, keep queue open if (!children.isEmpty()) { LOGGER.debug("MainQueue staying open: {} children remaining", children.size()); return; } - // No children left - check if task is finalized before auto-closing + // No children left - check if task is finalized before closing + // IMPORTANT: This check must happen BEFORE the immediate flag check + // to prevent closing queues for non-final tasks (fire-and-forget, resubscription support) if (taskStateProvider != null && taskId != null) { boolean isFinalized = taskStateProvider.isTaskFinalized(taskId); if (!isFinalized) { @@ -493,6 +514,36 @@ void childClosing(ChildQueue child, boolean immediate) { 
this.doClose(immediate); } + /** + * Distribute event to all ChildQueues. + * Called by MainEventBusProcessor after TaskStore persistence. + */ + void distributeToChildren(EventQueueItem item) { + int childCount = children.size(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("MainQueue[{}]: Distributing event {} to {} children", + taskId, item.getEvent().getClass().getSimpleName(), childCount); + } + children.forEach(child -> { + LOGGER.debug("MainQueue[{}]: Enqueueing event {} to child queue", + taskId, item.getEvent().getClass().getSimpleName()); + child.internalEnqueueItem(item); + }); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("MainQueue[{}]: Completed distribution of {} to {} children", + taskId, item.getEvent().getClass().getSimpleName(), childCount); + } + } + + /** + * Release the semaphore after event processing is complete. + * Called by MainEventBusProcessor in finally block to ensure release even on exceptions. + * Balances the acquire() in enqueueEvent() - protects MainEventBus throughput. + */ + void releaseSemaphore() { + semaphore.release(); + } + /** * Get the count of active child queues. * Used for testing to verify reference counting mechanism. 
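The semaphore handshake described here (producer acquires on enqueue, processor releases in a finally block after processing, `size()` derived from the remaining permits) can be sketched with plain JDK classes. This is an illustrative model, not the project code:

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Semaphore;

// Sketch: backpressure where permits are acquired on enqueue and released only
// after processing, so size() reflects queued + in-flight events.
public class BackpressureSketch {
    private final int capacity = 4;
    private final Semaphore permits = new Semaphore(capacity, true);
    private final BlockingQueue<String> bus = new LinkedBlockingDeque<>();

    void enqueue(String event) throws InterruptedException {
        permits.acquire();        // blocks producers once `capacity` events are in flight
        bus.put(event);
    }

    String process() throws InterruptedException {
        String event = bus.take();
        try {
            return event;         // persistence + distribution would happen here
        } finally {
            permits.release();    // always release, even if processing throws
        }
    }

    int size() {                  // in-flight = queued + currently being processed
        return capacity - permits.availablePermits();
    }

    public static void main(String[] args) throws InterruptedException {
        BackpressureSketch q = new BackpressureSketch();
        q.enqueue("e1");
        q.enqueue("e2");
        System.out.println(q.size());
        q.process();
        System.out.println(q.size());
    }
}
```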
@@ -539,10 +590,17 @@ public void close(boolean immediate) { public void close(boolean immediate, boolean notifyParent) { throw new UnsupportedOperationException("MainQueue does not support notifyParent parameter - use close(boolean) instead"); } + + String getTaskId() { + return taskId; + } } static class ChildQueue extends EventQueue { private final MainQueue parent; + private final BlockingQueue queue = new LinkedBlockingDeque<>(); + private volatile boolean immediateClose = false; + private volatile boolean awaitingFinalEvent = false; public ChildQueue(MainQueue parent) { this.parent = parent; @@ -553,8 +611,94 @@ public void enqueueEvent(Event event) { parent.enqueueEvent(event); } + @Override + public void enqueueItem(EventQueueItem item) { + // ChildQueue delegates writes to parent MainQueue + parent.enqueueItem(item); + } + private void internalEnqueueItem(EventQueueItem item) { - super.enqueueItem(item); + // Internal method called by MainEventBusProcessor to add to local queue + // Note: Semaphore is managed by parent MainQueue (acquire/release), not ChildQueue + Event event = item.getEvent(); + // For graceful close: still accept events so they can be drained by EventConsumer + // For immediate close: reject events to stop distribution quickly + if (isClosed() && immediateClose) { + LOGGER.warn("ChildQueue is immediately closed. Event will not be enqueued. {} {}", this, event); + return; + } + if (!queue.offer(item)) { + LOGGER.warn("ChildQueue {} is full. Closing immediately.", this); + close(true); // immediate close + } else { + LOGGER.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this); + + // If we were awaiting a final event and this is it, clear the flag + if (awaitingFinalEvent && isFinalEvent(event)) { + awaitingFinalEvent = false; + LOGGER.debug("ChildQueue {} received awaited final event", System.identityHashCode(this)); + } + } + } + + /** + * Checks if an event represents a final task state. 
+ */ + private boolean isFinalEvent(Event event) { + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); + } + return false; + } + + @Override + public void enqueueLocalOnly(EventQueueItem item) { + internalEnqueueItem(item); + } + + @Override + @Nullable + public EventQueueItem dequeueEventItem(int waitMilliSeconds) throws EventQueueClosedException { + // For immediate close: exit immediately even if queue is not empty (race with MainEventBusProcessor) + // For graceful close: only exit when queue is empty (wait for all events to be consumed) + // BUT: if awaiting final event, keep polling even if closed and empty + if (isClosed() && (queue.isEmpty() || immediateClose) && !awaitingFinalEvent) { + LOGGER.debug("ChildQueue is closed{}, sending termination message. {} (queueSize={})", + immediateClose ? " (immediate)" : " and empty", + this, + queue.size()); + throw new EventQueueClosedException(); + } + try { + if (waitMilliSeconds <= 0) { + EventQueueItem item = queue.poll(); + if (item != null) { + Event event = item.getEvent(); + LOGGER.debug("Dequeued event item (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event); + } + return item; + } + try { + LOGGER.trace("Polling ChildQueue {} (wait={}ms)", System.identityHashCode(this), waitMilliSeconds); + EventQueueItem item = queue.poll(waitMilliSeconds, TimeUnit.MILLISECONDS); + if (item != null) { + Event event = item.getEvent(); + LOGGER.debug("Dequeued event item (waiting) {} {}", this, event instanceof Throwable ? 
event.toString() : event); + } else { + LOGGER.trace("Dequeue timeout (null) from ChildQueue {}", System.identityHashCode(this)); + } + return item; + } catch (InterruptedException e) { + LOGGER.debug("Interrupted dequeue (waiting) {}", this); + Thread.currentThread().interrupt(); + return null; + } + } finally { + signalQueuePollerStarted(); + } } @Override @@ -562,6 +706,12 @@ public EventQueue tap() { throw new IllegalStateException("Can only tap the main queue"); } + @Override + public int size() { + // Return size of local consumption queue + return queue.size(); + } + @Override public void awaitQueuePollerStart() throws InterruptedException { parent.awaitQueuePollerStart(); @@ -572,6 +722,29 @@ public void signalQueuePollerStarted() { parent.signalQueuePollerStarted(); } + @Override + protected void doClose(boolean immediate) { + super.doClose(immediate); // Sets closed flag + if (immediate) { + // Immediate close: clear pending events from local queue + this.immediateClose = true; + int clearedCount = queue.size(); + queue.clear(); + LOGGER.debug("Cleared {} events from ChildQueue for immediate close: {}", clearedCount, this); + } + // For graceful close, let the queue drain naturally through normal consumption + } + + /** + * Notifies this ChildQueue to expect a final event. + * Called by MainQueue when it enqueues a final event, BEFORE submitting to MainEventBus. + * This ensures the ChildQueue keeps polling until the final event arrives (after MainEventBusProcessor). 
+ */ + void expectFinalEvent() { + awaitingFinalEvent = true; + LOGGER.debug("ChildQueue {} now awaiting final event", System.identityHashCode(this)); + } + @Override public void close() { close(false); diff --git a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java index e5a17e0e7..53a089e4c 100644 --- a/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java @@ -34,16 +34,20 @@ protected InMemoryQueueManager() { this.taskStateProvider = null; } + MainEventBus mainEventBus; + @Inject - public InMemoryQueueManager(TaskStateProvider taskStateProvider) { + public InMemoryQueueManager(TaskStateProvider taskStateProvider, MainEventBus mainEventBus) { + this.mainEventBus = mainEventBus; this.factory = new DefaultEventQueueFactory(); this.taskStateProvider = taskStateProvider; } - // For testing with custom factory - public InMemoryQueueManager(EventQueueFactory factory, TaskStateProvider taskStateProvider) { + // For testing/extensions with custom factory and MainEventBus + public InMemoryQueueManager(EventQueueFactory factory, TaskStateProvider taskStateProvider, MainEventBus mainEventBus) { this.factory = factory; this.taskStateProvider = taskStateProvider; + this.mainEventBus = mainEventBus; } @Override @@ -101,7 +105,6 @@ public EventQueue createOrTap(String taskId) { EventQueue newQueue = null; if (existing == null) { // Use builder pattern for cleaner queue creation - // Use the new taskId-aware builder method if available newQueue = factory.builder(taskId).build(); // Make sure an existing queue has not been added in the meantime existing = queues.putIfAbsent(taskId, newQueue); @@ -128,6 +131,12 @@ public void awaitQueuePollerStart(EventQueue eventQueue) throws InterruptedExcep eventQueue.awaitQueuePollerStart(); } + @Override + public EventQueue.EventQueueBuilder 
getEventQueueBuilder(String taskId) { + // Use the factory to ensure proper configuration (MainEventBus, callbacks, etc.) + return factory.builder(taskId); + } + @Override public int getActiveChildQueueCount(String taskId) { EventQueue queue = queues.get(taskId); @@ -142,6 +151,14 @@ public int getActiveChildQueueCount(String taskId) { return -1; } + @Override + public EventQueue.EventQueueBuilder createBaseEventQueueBuilder(String taskId) { + return EventQueue.builder(mainEventBus) + .taskId(taskId) + .addOnCloseCallback(getCleanupCallback(taskId)) + .taskStateProvider(taskStateProvider); + } + /** * Get the cleanup callback that removes a queue from the map when it closes. * This is exposed so that subclasses (like ReplicatedQueueManager) can reuse @@ -181,11 +198,8 @@ public Runnable getCleanupCallback(String taskId) { private class DefaultEventQueueFactory implements EventQueueFactory { @Override public EventQueue.EventQueueBuilder builder(String taskId) { - // Return builder with callback that removes queue from map when closed - return EventQueue.builder() - .taskId(taskId) - .addOnCloseCallback(getCleanupCallback(taskId)) - .taskStateProvider(taskStateProvider); + // Delegate to the base builder creation method + return createBaseEventQueueBuilder(taskId); } } } diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBus.java b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java new file mode 100644 index 000000000..90080b1e2 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBus.java @@ -0,0 +1,42 @@ +package io.a2a.server.events; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import jakarta.enterprise.context.ApplicationScoped; + +@ApplicationScoped +public class MainEventBus { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBus.class); + private final BlockingQueue 
queue; + + public MainEventBus() { + this.queue = new LinkedBlockingDeque<>(); + } + + void submit(String taskId, EventQueue.MainQueue mainQueue, EventQueueItem item) { + try { + queue.put(new MainEventBusContext(taskId, mainQueue, item)); + LOGGER.debug("Submitted event for task {} to MainEventBus (queue size: {})", + taskId, queue.size()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted submitting to MainEventBus", e); + } + } + + MainEventBusContext take() throws InterruptedException { + LOGGER.debug("MainEventBus: Waiting to take event (current queue size: {})...", queue.size()); + MainEventBusContext context = queue.take(); + LOGGER.debug("MainEventBus: Took event for task {} (remaining queue size: {})", + context.taskId(), queue.size()); + return context; + } + + public int size() { + return queue.size(); + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java new file mode 100644 index 000000000..292a60f21 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusContext.java @@ -0,0 +1,11 @@ +package io.a2a.server.events; + +import java.util.Objects; + +record MainEventBusContext(String taskId, EventQueue.MainQueue eventQueue, EventQueueItem eventQueueItem) { + MainEventBusContext { + Objects.requireNonNull(taskId, "taskId cannot be null"); + Objects.requireNonNull(eventQueue, "eventQueue cannot be null"); + Objects.requireNonNull(eventQueueItem, "eventQueueItem cannot be null"); + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java new file mode 100644 index 000000000..fbfc4aa0b --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java @@ -0,0 +1,385 @@ +package io.a2a.server.events; + +import 
java.util.concurrent.CompletableFuture; + +import jakarta.annotation.Nullable; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; + +import io.a2a.server.tasks.PushNotificationSender; +import io.a2a.server.tasks.TaskManager; +import io.a2a.server.tasks.TaskStore; +import io.a2a.spec.A2AServerException; +import io.a2a.spec.Event; +import io.a2a.spec.InternalError; +import io.a2a.spec.Message; +import io.a2a.spec.Task; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Background processor for the MainEventBus. + *
<p>
+ * This processor runs in a dedicated background thread, consuming events from the MainEventBus + * and performing two critical operations in order: + *
<p>
+ *
+ * <ol>
+ *   <li>Update TaskStore with event data (persistence FIRST)</li>
+ *   <li>Distribute event to ChildQueues (clients see it AFTER persistence)</li>
+ * </ol>
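A minimal JDK-only model of this persist-then-distribute ordering (hypothetical names; a single consumer thread drains a shared queue, as MainEventBusProcessor does, so subscribers can never observe an event that is not yet in the store):

```java
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingDeque;

// Sketch: one background consumer persists each event before distributing it.
public class OrderedProcessorSketch {
    static final BlockingQueue<String> bus = new LinkedBlockingDeque<>();
    static final List<String> store = new CopyOnWriteArrayList<>();      // stands in for TaskStore
    static final List<String> subscriber = new CopyOnWriteArrayList<>(); // stands in for a ChildQueue

    public static void main(String[] args) throws InterruptedException {
        Thread processor = new Thread(() -> {
            try {
                while (true) {
                    String event = bus.take();
                    if (event.equals("POISON")) {
                        break;                  // shutdown sentinel
                    }
                    store.add(event);           // 1. persist FIRST
                    subscriber.add(event);      // 2. distribute AFTER persistence
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }, "processor");
        processor.setDaemon(true);
        processor.start();

        bus.put("working");
        bus.put("completed");
        bus.put("POISON");
        processor.join();

        // Serial processing keeps store and subscriber views identical and ordered.
        System.out.println(store);
        System.out.println(subscriber);
    }
}
```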
+ *
<p>
+ * This architecture ensures clients never receive events before they're persisted, + * eliminating race conditions and enabling reliable event replay. + *
<p>
+ *
<p>
+ * Note: This bean is eagerly initialized by {@link MainEventBusProcessorInitializer} + * to ensure the background thread starts automatically when the application starts. + *
<p>
+ */ +@ApplicationScoped +public class MainEventBusProcessor implements Runnable { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessor.class); + + /** + * Callback for testing synchronization with async event processing. + * Default is NOOP to avoid null checks in production code. + * Tests can inject their own callback via setCallback(). + */ + private volatile MainEventBusProcessorCallback callback = MainEventBusProcessorCallback.NOOP; + + /** + * Optional executor for push notifications. + * If null, uses default ForkJoinPool (async). + * Tests can inject a synchronous executor to ensure deterministic ordering. + */ + private volatile @Nullable java.util.concurrent.Executor pushNotificationExecutor = null; + + private final MainEventBus eventBus; + + private final TaskStore taskStore; + + private final PushNotificationSender pushSender; + + private final QueueManager queueManager; + + private volatile boolean running = true; + private @Nullable Thread processorThread; + + @Inject + public MainEventBusProcessor(MainEventBus eventBus, TaskStore taskStore, PushNotificationSender pushSender, QueueManager queueManager) { + this.eventBus = eventBus; + this.taskStore = taskStore; + this.pushSender = pushSender; + this.queueManager = queueManager; + } + + /** + * Set a callback for testing synchronization with async event processing. + *
<p>
+ * This is primarily intended for tests that need to wait for event processing to complete. + * Pass null to reset to the default NOOP callback. + *
<p>
+ * + * @param callback the callback to invoke during event processing, or null for NOOP + */ + public void setCallback(MainEventBusProcessorCallback callback) { + this.callback = callback != null ? callback : MainEventBusProcessorCallback.NOOP; + } + + /** + * Set a custom executor for push notifications (primarily for testing). + *
<p>
+ * By default, push notifications are sent asynchronously using CompletableFuture.runAsync() + * with the default ForkJoinPool. For tests that need deterministic ordering of push + * notifications, inject a synchronous executor that runs tasks immediately on the calling thread. + *
<p>
+ * Example synchronous executor for tests: + *
<pre>{@code
+     * Executor syncExecutor = Runnable::run;
+     * mainEventBusProcessor.setPushNotificationExecutor(syncExecutor);
+     * }</pre>
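The effect of such an inline executor can be demonstrated with a self-contained snippet (JDK only; the event names are invented for illustration). Because `Runnable::run` executes the task on the calling thread, the "async" callback is fully sequenced before `execute()` returns:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executor;

// Sketch: an inline Executor makes callback ordering deterministic in tests.
public class SyncExecutorSketch {
    public static void main(String[] args) {
        List<String> order = new ArrayList<>();
        Executor syncExecutor = Runnable::run;   // runs tasks on the calling thread

        order.add("persist");
        syncExecutor.execute(() -> order.add("push-notification"));
        order.add("distribute");

        // With the inline executor, the callback lands between the two steps.
        System.out.println(order);
    }
}
```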
+ * + * @param executor the executor to use for push notifications, or null to use default ForkJoinPool + */ + public void setPushNotificationExecutor(java.util.concurrent.Executor executor) { + this.pushNotificationExecutor = executor; + } + + @PostConstruct + void start() { + processorThread = new Thread(this, "MainEventBusProcessor"); + processorThread.setDaemon(true); // Allow JVM to exit even if this thread is running + processorThread.start(); + LOGGER.info("MainEventBusProcessor started"); + } + + /** + * No-op method to force CDI proxy resolution and ensure @PostConstruct has been called. + * Called by MainEventBusProcessorInitializer during application startup. + */ + public void ensureStarted() { + // Method intentionally empty - just forces proxy resolution + } + + @PreDestroy + void stop() { + LOGGER.info("MainEventBusProcessor stopping..."); + running = false; + if (processorThread != null) { + processorThread.interrupt(); + try { + long start = System.currentTimeMillis(); + processorThread.join(5000); // Wait up to 5 seconds + long elapsed = System.currentTimeMillis() - start; + LOGGER.info("MainEventBusProcessor thread stopped in {}ms", elapsed); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.warn("Interrupted while waiting for MainEventBusProcessor thread to stop"); + } + } + LOGGER.info("MainEventBusProcessor stopped"); + } + + @Override + public void run() { + LOGGER.info("MainEventBusProcessor processing loop started"); + while (running) { + try { + LOGGER.debug("MainEventBusProcessor: Waiting for event from MainEventBus..."); + MainEventBusContext context = eventBus.take(); + LOGGER.debug("MainEventBusProcessor: Retrieved event for task {} from MainEventBus", + context.taskId()); + processEvent(context); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.info("MainEventBusProcessor interrupted, shutting down"); + break; + } catch (Exception e) { + LOGGER.error("Error 
processing event from MainEventBus", e); + // Continue processing despite errors + } + } + LOGGER.info("MainEventBusProcessor processing loop ended"); + } + + private void processEvent(MainEventBusContext context) { + String taskId = context.taskId(); + Event event = context.eventQueueItem().getEvent(); + // MainEventBus.submit() guarantees this is always a MainQueue + EventQueue.MainQueue mainQueue = (EventQueue.MainQueue) context.eventQueue(); + + LOGGER.debug("MainEventBusProcessor: Processing event for task {}: {}", + taskId, event.getClass().getSimpleName()); + + Event eventToDistribute = null; + try { + // Step 1: Update TaskStore FIRST (persistence before clients see it) + // If this throws, we distribute an error to ensure "persist before client visibility" + + try { + boolean isReplicated = context.eventQueueItem().isReplicated(); + boolean isFinal = updateTaskStore(taskId, event, isReplicated); + + eventToDistribute = event; // Success - distribute original event + + // Trigger replication AFTER successful persistence + // SKIP replication if task is final - ReplicatedQueueManager handles this via TaskFinalizedEvent + // to ensure final Task is sent before poison pill (QueueClosedEvent) + if (!isFinal) { + EventEnqueueHook hook = mainQueue.getEnqueueHook(); + if (hook != null) { + LOGGER.debug("Triggering replication hook for task {} after successful persistence", taskId); + hook.onEnqueue(context.eventQueueItem()); + } + } else { + LOGGER.debug("Task {} is final - skipping replication hook (handled by ReplicatedQueueManager)", taskId); + } + } catch (InternalError e) { + // Persistence failed - create error event to distribute instead + LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e); + String errorMessage = "Failed to persist event: " + e.getMessage(); + eventToDistribute = e; + } catch (Exception e) { + LOGGER.error("Failed to persist event for task {}, distributing error to clients", taskId, e); + String 
errorMessage = "Failed to persist event: " + e.getMessage(); + eventToDistribute = new InternalError(errorMessage); + } + + // Step 2: Send push notification AFTER successful persistence + if (eventToDistribute == event) { + // Capture task state immediately after persistence, before going async + // This ensures we send the task as it existed when THIS event was processed, + // not whatever state might exist later when the async callback executes + Task taskSnapshot = taskStore.get(taskId); + if (taskSnapshot != null) { + sendPushNotification(taskId, taskSnapshot); + } else { + LOGGER.warn("Task {} not found in TaskStore after successful persistence, skipping push notification", taskId); + } + } + + // Step 3: Then distribute to ChildQueues (clients see either event or error AFTER persistence attempt) + if (eventToDistribute == null) { + LOGGER.error("MainEventBusProcessor: eventToDistribute is NULL for task {} - this should never happen!", taskId); + eventToDistribute = new InternalError("Internal error: event processing failed"); + } + + int childCount = mainQueue.getChildCount(); + LOGGER.debug("MainEventBusProcessor: Distributing {} to {} children for task {}", + eventToDistribute.getClass().getSimpleName(), childCount, taskId); + // Create new EventQueueItem with the event to distribute (original or error) + EventQueueItem itemToDistribute = new LocalEventQueueItem(eventToDistribute); + mainQueue.distributeToChildren(itemToDistribute); + LOGGER.debug("MainEventBusProcessor: Distributed {} to {} children for task {}", + eventToDistribute.getClass().getSimpleName(), childCount, taskId); + + LOGGER.debug("MainEventBusProcessor: Completed processing event for task {}", taskId); + + } finally { + try { + // Step 4: Notify callback after all processing is complete + // Call callback with the distributed event (original or error) + if (eventToDistribute != null) { + callback.onEventProcessed(taskId, eventToDistribute); + + // Step 5: If this is a final event, notify 
task finalization + // Only for successful persistence (not for errors) + if (eventToDistribute == event && isFinalEvent(event)) { + callback.onTaskFinalized(taskId); + } + } + } finally { + // ALWAYS release semaphore, even if processing fails + // Balances the acquire() in MainQueue.enqueueEvent() + mainQueue.releaseSemaphore(); + } + } + } + + /** + * Updates TaskStore using TaskManager.process(). + *
<p>
+ * Creates a temporary TaskManager instance for this event and delegates to its process() method, + * which handles all event types (Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent). + * This leverages existing TaskManager logic for status updates, artifact appending, message history, etc. + *
</p>
+ *
<p>
+ * If persistence fails, the exception is propagated to processEvent() which distributes an + * InternalError to clients instead of the original event, ensuring "persist before visibility". + * See Gemini's comment: https://github.com/a2aproject/a2a-java/pull/515#discussion_r2604621833 + *
</p>
+ * + * @param taskId the task ID + * @param event the event to persist + * @param isReplicated true if the event was replicated from another instance rather than produced locally + * @return true if the task reached a final state, false otherwise + * @throws InternalError if persistence fails + */ + private boolean updateTaskStore(String taskId, Event event, boolean isReplicated) throws InternalError { + try { + // Extract contextId from event (all relevant events have it) + String contextId = extractContextId(event); + + // Create temporary TaskManager instance for this event + TaskManager taskManager = new TaskManager(taskId, contextId, taskStore, null); + + // Use TaskManager.process() - handles all event types with existing logic + boolean isFinal = taskManager.process(event, isReplicated); + LOGGER.debug("TaskStore updated via TaskManager.process() for task {}: {} (final: {}, replicated: {})", + taskId, event.getClass().getSimpleName(), isFinal, isReplicated); + return isFinal; + } catch (InternalError e) { + LOGGER.error("Error updating TaskStore via TaskManager for task {}", taskId, e); + // Rethrow to prevent distributing unpersisted event to clients + throw e; + } catch (Exception e) { + LOGGER.error("Unexpected error updating TaskStore for task {}", taskId, e); + // Rethrow to prevent distributing unpersisted event to clients + throw new InternalError("TaskStore persistence failed: " + e.getMessage()); + } + } + + /** + * Sends push notification for the task AFTER persistence. + *
<p>
+ * This is called after updateTaskStore() to ensure the notification contains + * the latest persisted state, avoiding race conditions. + *
</p>
+ *
<p>
+ * CRITICAL: Push notifications are sent asynchronously in the background + * to avoid blocking event distribution to ChildQueues. The 83ms overhead from + * PushNotificationSender.sendNotification() was causing streaming delays. + *
</p>
+ *
<p>
+ * IMPORTANT: The task parameter is a snapshot captured immediately after + * persistence. This ensures we send the task state as it existed when THIS event + * was processed, not whatever state might exist in TaskStore when the async + * callback executes (subsequent events may have already updated the store). + *
</p>
+ *
<p>
+ * NOTE: Tests can inject a synchronous executor via setPushNotificationExecutor() + * to ensure deterministic ordering of push notifications in the test environment. + *
</p>
+ * + * @param taskId the task ID + * @param task the task snapshot to send (captured immediately after persistence) + */ + private void sendPushNotification(String taskId, Task task) { + Runnable pushTask = () -> { + try { + if (task != null) { + LOGGER.debug("Sending push notification for task {}", taskId); + pushSender.sendNotification(task); + } else { + LOGGER.debug("Skipping push notification - task snapshot is null for task {}", taskId); + } + } catch (Exception e) { + LOGGER.error("Error sending push notification for task {}", taskId, e); + // Don't rethrow - push notifications are best-effort + } + }; + + // Use custom executor if set (for tests), otherwise use default ForkJoinPool (async) + if (pushNotificationExecutor != null) { + pushNotificationExecutor.execute(pushTask); + } else { + CompletableFuture.runAsync(pushTask); + } + } + + /** + * Extracts contextId from an event. + * Returns null if the event type doesn't have a contextId (e.g., Message). + */ + @Nullable + private String extractContextId(Event event) { + if (event instanceof Task task) { + return task.contextId(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.contextId(); + } else if (event instanceof TaskArtifactUpdateEvent artifactUpdate) { + return artifactUpdate.contextId(); + } + // Message and other events don't have contextId + return null; + } + + /** + * Checks if an event represents a final task state. 
+ * + * @param event the event to check + * @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN) + */ + private boolean isFinalEvent(Event event) { + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); + } + return false; + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java new file mode 100644 index 000000000..b0a9adbce --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorCallback.java @@ -0,0 +1,66 @@ +package io.a2a.server.events; + +import io.a2a.spec.Event; + +/** + * Callback interface for MainEventBusProcessor events. + *
<p>
+ * This interface is primarily intended for testing, allowing tests to synchronize + * with the asynchronous MainEventBusProcessor. Production code should not rely on this. + *
</p>
+ * Usage in tests:
+ * <pre>
+ * {@code
+ * @Inject
+ * MainEventBusProcessor processor;
+ *
+ * @BeforeEach
+ * void setUp() {
+ *     CountDownLatch latch = new CountDownLatch(3);
+ *     processor.setCallback(new MainEventBusProcessorCallback() {
+ *         @Override
+ *         public void onEventProcessed(String taskId, Event event) {
+ *             latch.countDown();
+ *         }
+ *
+ *         @Override
+ *         public void onTaskFinalized(String taskId) {
+ *             // no-op for this test
+ *         }
+ *     });
+ * }
+ *
+ * @AfterEach
+ * void tearDown() {
+ *     processor.setCallback(null); // Reset to NOOP
+ * }
+ * }
+ * </pre>
+ */ +public interface MainEventBusProcessorCallback { + + /** + * Called after an event has been fully processed (persisted, notification sent, distributed to children). + * + * @param taskId the task ID + * @param event the event that was processed + */ + void onEventProcessed(String taskId, Event event); + + /** + * Called when a task reaches a final state (COMPLETED, FAILED, CANCELED, REJECTED). + * + * @param taskId the task ID that was finalized + */ + void onTaskFinalized(String taskId); + + /** + * No-op implementation that does nothing. + * Used as the default callback to avoid null checks. + */ + MainEventBusProcessorCallback NOOP = new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + // No-op + } + + @Override + public void onTaskFinalized(String taskId) { + // No-op + } + }; +} diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java new file mode 100644 index 000000000..ba4b300be --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessorInitializer.java @@ -0,0 +1,43 @@ +package io.a2a.server.events; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.context.Initialized; +import jakarta.enterprise.event.Observes; +import jakarta.inject.Inject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Portable CDI initializer for MainEventBusProcessor. + *
<p>
+ * This bean observes the ApplicationScoped initialization event and injects + * MainEventBusProcessor, which triggers its eager creation and starts the background thread. + *
</p>
+ *
<p>
+ * This approach is portable across all Jakarta CDI implementations (Weld, OpenWebBeans, Quarkus, etc.) + * and ensures MainEventBusProcessor starts automatically when the application starts. + *
</p>
+ */ +@ApplicationScoped +public class MainEventBusProcessorInitializer { + private static final Logger LOGGER = LoggerFactory.getLogger(MainEventBusProcessorInitializer.class); + + @Inject + MainEventBusProcessor processor; + + /** + * Observes ApplicationScoped initialization to force eager creation of MainEventBusProcessor. + * The injection of MainEventBusProcessor in this bean triggers its creation, and calling + * ensureStarted() forces the CDI proxy to be resolved, which ensures @PostConstruct has been + * called and the background thread is running. + */ + void onStart(@Observes @Initialized(ApplicationScoped.class) Object event) { + if (processor != null) { + // Force proxy resolution to ensure @PostConstruct has been called + processor.ensureStarted(); + LOGGER.info("MainEventBusProcessor initialized and started"); + } else { + LOGGER.error("MainEventBusProcessor is null - initialization failed!"); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/events/QueueManager.java b/server-common/src/main/java/io/a2a/server/events/QueueManager.java index 01e754fcb..4ad30f0cb 100644 --- a/server-common/src/main/java/io/a2a/server/events/QueueManager.java +++ b/server-common/src/main/java/io/a2a/server/events/QueueManager.java @@ -177,7 +177,31 @@ public interface QueueManager { * @return a builder for creating event queues */ default EventQueue.EventQueueBuilder getEventQueueBuilder(String taskId) { - return EventQueue.builder(); + throw new UnsupportedOperationException( + "QueueManager implementations must override getEventQueueBuilder() to provide MainEventBus" + ); + } + + /** + * Creates a base EventQueueBuilder with standard configuration for this QueueManager. + * This method provides the foundation for creating event queues with proper configuration + * (MainEventBus, TaskStateProvider, cleanup callbacks, etc.). + *
<p>
+ * QueueManager implementations that use custom factories can call this method directly + * to get the base builder without going through the factory (which could cause infinite + * recursion if the factory delegates back to getEventQueueBuilder()). + *
</p>
+ *
<p>
+ * Callers can then add additional configuration (hooks, callbacks) before building the queue. + *
</p>
+ * + * @param taskId the task ID for the queue + * @return a builder with base configuration specific to this QueueManager implementation + */ + default EventQueue.EventQueueBuilder createBaseEventQueueBuilder(String taskId) { + throw new UnsupportedOperationException( + "QueueManager implementations must override createBaseEventQueueBuilder() to provide MainEventBus" + ); } /** diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 002acbafd..77ba83b8b 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -11,13 +11,13 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -35,13 +35,15 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; import io.a2a.server.events.EventQueueItem; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.server.events.MainEventBusProcessorCallback; import io.a2a.server.events.QueueManager; -import io.a2a.server.events.TaskQueueExistsException; import io.a2a.server.tasks.PushNotificationConfigStore; import io.a2a.server.tasks.PushNotificationSender; import io.a2a.server.tasks.ResultAggregator; import io.a2a.server.tasks.TaskManager; import io.a2a.server.tasks.TaskStore; +import 
io.a2a.server.util.async.EventConsumerExecutorProducer.EventConsumerExecutor; import io.a2a.server.util.async.Internal; import io.a2a.spec.A2AError; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; @@ -64,6 +66,7 @@ import io.a2a.spec.TaskPushNotificationConfig; import io.a2a.spec.TaskQueryParams; import io.a2a.spec.TaskState; +import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.UnsupportedOperationError; import org.jspecify.annotations.NonNull; import org.jspecify.annotations.Nullable; @@ -122,7 +125,6 @@ *
  • {@link EventConsumer} polls and processes events on Vert.x worker thread
  • *
  • Queue closes automatically on final event (COMPLETED/FAILED/CANCELED)
  • *
  • Cleanup waits for both agent execution AND event consumption to complete
  • - *
  • Background tasks tracked via {@link #trackBackgroundTask(CompletableFuture)}
  • * * *
<h2>Threading Model</h2>
    @@ -179,6 +181,13 @@ public class DefaultRequestHandler implements RequestHandler { private static final Logger LOGGER = LoggerFactory.getLogger(DefaultRequestHandler.class); + /** + * Separate logger for thread statistics diagnostic logging. + * This allows independent control of verbose thread pool monitoring without affecting + * general request handler logging. Enable with: logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG + */ + private static final Logger THREAD_STATS_LOGGER = LoggerFactory.getLogger("io.a2a.server.diagnostics.ThreadStats"); + private static final String A2A_BLOCKING_AGENT_TIMEOUT_SECONDS = "a2a.blocking.agent.timeout.seconds"; private static final String A2A_BLOCKING_CONSUMPTION_TIMEOUT_SECONDS = "a2a.blocking.consumption.timeout.seconds"; @@ -214,13 +223,14 @@ public class DefaultRequestHandler implements RequestHandler { private TaskStore taskStore; private QueueManager queueManager; private PushNotificationConfigStore pushConfigStore; - private PushNotificationSender pushSender; + private MainEventBusProcessor mainEventBusProcessor; private Supplier requestContextBuilder; private final ConcurrentMap> runningAgents = new ConcurrentHashMap<>(); - private final Set> backgroundTasks = ConcurrentHashMap.newKeySet(); + private Executor executor; + private Executor eventConsumerExecutor; /** * No-args constructor for CDI proxy creation. 
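The `THREAD_STATS_LOGGER` introduced in the hunk above uses a dedicated logger name so verbose thread-pool monitoring can be toggled independently of the class's regular logger. A minimal sketch of the same pattern follows, using `java.util.logging` so it is dependency-free (the patch itself uses SLF4J; `ThreadStatsLoggingSketch` is a hypothetical stand-in class, not part of the patch):

```java
import java.util.logging.Level;
import java.util.logging.Logger;

// Sketch of the "separate diagnostics logger" pattern, assuming java.util.logging
// in place of SLF4J. ThreadStatsLoggingSketch is hypothetical, not part of the patch.
public class ThreadStatsLoggingSketch {
    private static final Logger LOGGER =
            Logger.getLogger(ThreadStatsLoggingSketch.class.getName());

    // Dedicated logger name, so operators can raise/lower its level independently
    // (the SLF4J equivalent: logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG)
    private static final Logger THREAD_STATS_LOGGER =
            Logger.getLogger("io.a2a.server.diagnostics.ThreadStats");

    public static void logThreadStats() {
        // Guard the (potentially expensive) message construction
        if (THREAD_STATS_LOGGER.isLoggable(Level.FINE)) {
            THREAD_STATS_LOGGER.fine("active threads: " + Thread.activeCount());
        }
    }

    public static void main(String[] args) {
        System.out.println(LOGGER == THREAD_STATS_LOGGER); // prints false: distinct loggers
        THREAD_STATS_LOGGER.setLevel(Level.OFF);
        logThreadStats(); // no-op while the diagnostics logger is OFF
    }
}
```

The key design point mirrored here is that the diagnostics logger is looked up by an explicit string name rather than by class, decoupling its configuration from the owning class.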
@@ -234,21 +244,25 @@ protected DefaultRequestHandler() { this.taskStore = null; this.queueManager = null; this.pushConfigStore = null; - this.pushSender = null; + this.mainEventBusProcessor = null; this.requestContextBuilder = null; this.executor = null; + this.eventConsumerExecutor = null; } @Inject public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore, QueueManager queueManager, PushNotificationConfigStore pushConfigStore, - PushNotificationSender pushSender, @Internal Executor executor) { + MainEventBusProcessor mainEventBusProcessor, + @Internal Executor executor, + @EventConsumerExecutor Executor eventConsumerExecutor) { this.agentExecutor = agentExecutor; this.taskStore = taskStore; this.queueManager = queueManager; this.pushConfigStore = pushConfigStore; - this.pushSender = pushSender; + this.mainEventBusProcessor = mainEventBusProcessor; this.executor = executor; + this.eventConsumerExecutor = eventConsumerExecutor; // TODO In Python this is also a constructor parameter defaulting to this SimpleRequestContextBuilder // implementation if the parameter is null. Skip that for now, since otherwise I get CDI errors, and // I am unsure about the correct scope. 
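The constructor above now injects `MainEventBusProcessor` in place of `PushNotificationSender`. The ordering guarantee that processor provides — persist to the store first, distribute to subscribers second, all on a single background thread draining a shared `BlockingQueue` — can be sketched with hypothetical stand-in types (`MainBusSketch`, `BusEvent` are illustrations, not the real io.a2a classes):

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical stand-in types -- NOT the real io.a2a classes. Sketch of the
// MainEventBus ordering guarantee: one background thread drains a shared
// BlockingQueue, persists each event FIRST, and only then distributes it,
// so subscribers can never observe an event the store has not yet seen.
public class MainBusSketch {
    public record BusEvent(String taskId, String payload) {}

    private final BlockingQueue<BusEvent> bus = new LinkedBlockingQueue<>();
    public final Map<String, String> store = new ConcurrentHashMap<>();   // stands in for TaskStore
    public final List<BusEvent> delivered = new CopyOnWriteArrayList<>(); // stands in for ChildQueues

    private final Thread processor = new Thread(() -> {
        try {
            while (true) {
                BusEvent e = bus.take();
                if (e.payload() == null) {  // poison pill stops the processor
                    return;
                }
                store.put(e.taskId(), e.payload()); // Step 1: persist
                delivered.add(e);                   // Step 2: distribute
            }
        } catch (InterruptedException ignored) {
            // shutting down
        }
    });

    public void start() {
        processor.start();
    }

    public void publish(BusEvent e) {
        bus.add(e);
    }

    public void shutdownAndAwait() {
        bus.add(new BusEvent("", null)); // poison pill
        try {
            processor.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    public static void main(String[] args) {
        MainBusSketch sketch = new MainBusSketch();
        sketch.start();
        // Concurrent updates to the same task are serialized by the single processor
        sketch.publish(new BusEvent("task-1", "working"));
        sketch.publish(new BusEvent("task-1", "completed"));
        sketch.shutdownAndAwait();
        System.out.println(sketch.store.get("task-1")); // prints "completed"
        System.out.println(sketch.delivered.size());    // prints 2
    }
}
```

Because a single thread performs both steps, no lock is needed around the store update, which is the same serialization argument the patch makes for eliminating concurrent TaskStore updates.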
@@ -264,16 +278,20 @@ void initConfig() { configProvider.getValue(A2A_BLOCKING_CONSUMPTION_TIMEOUT_SECONDS)); } + /** * For testing */ public static DefaultRequestHandler create(AgentExecutor agentExecutor, TaskStore taskStore, QueueManager queueManager, PushNotificationConfigStore pushConfigStore, - PushNotificationSender pushSender, Executor executor) { + MainEventBusProcessor mainEventBusProcessor, + Executor executor, Executor eventConsumerExecutor) { DefaultRequestHandler handler = - new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, pushSender, executor); + new DefaultRequestHandler(agentExecutor, taskStore, queueManager, pushConfigStore, + mainEventBusProcessor, executor, eventConsumerExecutor); handler.agentCompletionTimeoutSeconds = 5; handler.consumptionCompletionTimeoutSeconds = 2; + return handler; } @@ -359,12 +377,9 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws taskStore, null); - ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); - EventQueue queue = queueManager.tap(task.id()); - if (queue == null) { - queue = queueManager.getEventQueueBuilder(task.id()).build(); - } + EventQueue queue = queueManager.createOrTap(task.id()); agentExecutor.cancel( requestContextBuilder.get() .setTaskId(task.id()) @@ -395,28 +410,41 @@ public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws @Override public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws A2AError { LOGGER.debug("onMessageSend - task: {}; context {}", params.message().taskId(), params.message().contextId()); + + // Build MessageSendSetup which creates RequestContext with real taskId (auto-generated if needed) MessageSendSetup mss = initMessageSend(params, context); - String taskId = mss.requestContext.getTaskId(); - 
LOGGER.debug("Request context taskId: {}", taskId); + // Use the taskId from RequestContext for queue management (no temp ID needed!) + // RequestContext.build() guarantees taskId is non-null via checkOrGenerateTaskId() + String queueTaskId = java.util.Objects.requireNonNull( + mss.requestContext.getTaskId(), "TaskId must be non-null after RequestContext.build()"); + LOGGER.debug("Queue taskId: {}", queueTaskId); - if (taskId == null) { - throw new io.a2a.spec.InternalError("Task ID is null in onMessageSend"); - } - EventQueue queue = queueManager.createOrTap(taskId); - ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor); + // Create queue with real taskId (no tempId parameter needed) + EventQueue queue = queueManager.createOrTap(queueTaskId); + final java.util.concurrent.atomic.AtomicReference<@NonNull String> taskId = new java.util.concurrent.atomic.AtomicReference<>(queueTaskId); + ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); + // Default to blocking=false per A2A spec (return after task creation) boolean blocking = params.configuration() != null && Boolean.TRUE.equals(params.configuration().blocking()); + // Log blocking behavior from client request + if (params.configuration() != null && params.configuration().blocking() != null) { + LOGGER.debug("DefaultRequestHandler: Client requested blocking={} for task {}", + params.configuration().blocking(), taskId.get()); + } else if (params.configuration() != null) { + LOGGER.debug("DefaultRequestHandler: Client sent configuration but blocking=null, using default blocking={} for task {}", blocking, taskId.get()); + } else { + LOGGER.debug("DefaultRequestHandler: Client sent no configuration, using default blocking={} for task {}", blocking, taskId.get()); + } + LOGGER.debug("DefaultRequestHandler: Final blocking decision: {} for task {}", blocking, taskId.get()); + boolean interruptedOrNonBlocking = false; - 
EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId, mss.requestContext, queue); + EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(queueTaskId, mss.requestContext, queue); ResultAggregator.EventTypeAndInterrupt etai = null; EventKind kind = null; // Declare outside try block so it's in scope for return try { - // Create callback for push notifications during background event processing - Runnable pushNotificationCallback = () -> sendPushNotification(taskId, resultAggregator); - EventConsumer consumer = new EventConsumer(queue); // This callback must be added before we start consuming. Otherwise, @@ -424,7 +452,7 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); // Get agent future before consuming (for blocking calls to wait for agent completion) - CompletableFuture agentFuture = runningAgents.get(taskId); + CompletableFuture agentFuture = runningAgents.get(queueTaskId); etai = resultAggregator.consumeAndBreakOnInterrupt(consumer, blocking); if (etai == null) { @@ -432,7 +460,8 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte throw new InternalError("No result"); } interruptedOrNonBlocking = etai.interrupted(); - LOGGER.debug("Was interrupted or non-blocking: {}", interruptedOrNonBlocking); + LOGGER.debug("DefaultRequestHandler: interruptedOrNonBlocking={} (blocking={}, eventType={})", + interruptedOrNonBlocking, blocking, kind != null ? kind.getClass().getSimpleName() : null); // For blocking calls that were interrupted (returned on first event), // wait for agent execution and event processing BEFORE returning to client. @@ -441,30 +470,36 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte // during the consumption loop itself. 
kind = etai.eventType(); + // No ID switching needed - agent uses context.getTaskId() which is the same as queue key + // Store push notification config for newly created tasks (mirrors streaming logic) // Only for NEW tasks - existing tasks are handled by initMessageSend() if (mss.task() == null && kind instanceof Task createdTask && shouldAddPushInfo(params)) { - LOGGER.debug("Storing push notification config for new task {}", createdTask.id()); + LOGGER.debug("Storing push notification config for new task {} (original taskId from params: {})", + createdTask.id(), params.message().taskId()); pushConfigStore.setInfo(createdTask.id(), params.configuration().pushNotificationConfig()); } if (blocking && interruptedOrNonBlocking) { - // For blocking calls: ensure all events are processed before returning - // Order of operations is critical to avoid circular dependency: + // For blocking calls: ensure all events are persisted to TaskStore before returning + // Order of operations is critical to avoid circular dependency and race conditions: // 1. Wait for agent to finish enqueueing events // 2. Close the queue to signal consumption can complete // 3. Wait for consumption to finish processing events - // 4. Fetch final task state from TaskStore + // 4. Wait for MainEventBusProcessor to persist final state to TaskStore + // 5. 
Fetch final task state from TaskStore (now guaranteed persisted) + LOGGER.debug("DefaultRequestHandler: Entering blocking fire-and-forget handling for task {}", taskId.get()); try { // Step 1: Wait for agent to finish (with configurable timeout) if (agentFuture != null) { try { agentFuture.get(agentCompletionTimeoutSeconds, SECONDS); - LOGGER.debug("Agent completed for task {}", taskId); + LOGGER.debug("DefaultRequestHandler: Step 1 - Agent completed for task {}", taskId.get()); } catch (java.util.concurrent.TimeoutException e) { // Agent still running after timeout - that's fine, events already being processed - LOGGER.debug("Agent still running for task {} after {}s", taskId, agentCompletionTimeoutSeconds); + LOGGER.debug("DefaultRequestHandler: Step 1 - Agent still running for task {} after {}s timeout", + taskId.get(), agentCompletionTimeoutSeconds); } } @@ -472,55 +507,77 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte // For fire-and-forget tasks, there's no final event, so we need to close the queue // This allows EventConsumer.consumeAll() to exit queue.close(false, false); // graceful close, don't notify parent yet - LOGGER.debug("Closed queue for task {} to allow consumption completion", taskId); + LOGGER.debug("DefaultRequestHandler: Step 2 - Closed queue for task {} to allow consumption completion", taskId.get()); // Step 3: Wait for consumption to complete (now that queue is closed) if (etai.consumptionFuture() != null) { etai.consumptionFuture().get(consumptionCompletionTimeoutSeconds, SECONDS); - LOGGER.debug("Consumption completed for task {}", taskId); + LOGGER.debug("DefaultRequestHandler: Step 3 - Consumption completed for task {}", taskId.get()); } + + // NOTE: We do NOT wait for MainEventBusProcessor to finalize the task here. + // This would require blocking, which breaks gRPC (Vert.x event loop). 
+ // In practice, events are processed within milliseconds, so the race + // condition where TaskStore is not fully updated is minimal. + // For platform-agnostic SDK design, we accept this minor race condition + // rather than blocking event loop threads. + } catch (InterruptedException e) { Thread.currentThread().interrupt(); - String msg = String.format("Error waiting for task %s completion", taskId); + String msg = String.format("Error waiting for task %s completion", taskId.get()); LOGGER.warn(msg, e); throw new InternalError(msg); } catch (java.util.concurrent.ExecutionException e) { - String msg = String.format("Error during task %s execution", taskId); + String msg = String.format("Error during task %s execution", taskId.get()); LOGGER.warn(msg, e.getCause()); throw new InternalError(msg); - } catch (java.util.concurrent.TimeoutException e) { - String msg = String.format("Timeout waiting for consumption to complete for task %s", taskId); - LOGGER.warn(msg, taskId); + } catch (TimeoutException e) { + // Timeout from consumption future.get() - different from finalization timeout + String msg = String.format("Timeout waiting for task %s consumption", taskId.get()); + LOGGER.warn(msg, e); throw new InternalError(msg); } - // Step 4: Fetch the final task state from TaskStore (all events have been processed) - // taskId is guaranteed non-null here (checked earlier) - String nonNullTaskId = taskId; + // Step 5: Fetch the final task state from TaskStore (now guaranteed persisted) + String nonNullTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); Task updatedTask = taskStore.get(nonNullTaskId); if (updatedTask != null) { kind = updatedTask; - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Fetched final task for {} with state {} and {} artifacts", - nonNullTaskId, updatedTask.status().state(), - updatedTask.artifacts().size()); - } + LOGGER.debug("DefaultRequestHandler: Step 5 - Fetched final task for {} with state {} and {} artifacts", + 
taskId.get(), updatedTask.status().state(), + updatedTask.artifacts().size()); + } else { + LOGGER.warn("DefaultRequestHandler: Step 5 - Task {} not found in TaskStore!", taskId.get()); } } - if (kind instanceof Task taskResult && !taskId.equals(taskResult.id())) { + String finalTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (kind instanceof Task taskResult && !finalTaskId.equals(taskResult.id())) { throw new InternalError("Task ID mismatch in agent response"); } - - // Send push notification after initial return (for both blocking and non-blocking) - pushNotificationCallback.run(); } finally { + // For non-blocking calls: close ChildQueue IMMEDIATELY to free EventConsumer thread + // CRITICAL: Must use immediate=true to clear the local queue, otherwise EventConsumer + // continues polling until queue drains naturally, holding executor thread. + // Immediate close clears pending events and triggers EventQueueClosedException on next poll. + // Events continue flowing through MainQueue → MainEventBus → TaskStore. 
+ if (!blocking && etai != null && etai.interrupted()) { + LOGGER.debug("DefaultRequestHandler: Non-blocking call in finally - closing ChildQueue IMMEDIATELY for task {} to free EventConsumer", taskId.get()); + queue.close(true); // immediate=true: clear queue and free EventConsumer + } + // Remove agent from map immediately to prevent accumulation - CompletableFuture agentFuture = runningAgents.remove(taskId); - LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId, runningAgents.size()); + CompletableFuture agentFuture = runningAgents.remove(queueTaskId); + String cleanupTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", cleanupTaskId, runningAgents.size()); - // Track cleanup as background task to avoid blocking Vert.x threads + // Cleanup as background task to avoid blocking Vert.x threads // Pass the consumption future to ensure cleanup waits for background consumption to complete - trackBackgroundTask(cleanupProducer(agentFuture, etai != null ? etai.consumptionFuture() : null, taskId, queue, false)); + cleanupProducer(agentFuture, etai != null ? 
etai.consumptionFuture() : null, cleanupTaskId, queue, false) + .whenComplete((res, err) -> { + if (err != null) { + LOGGER.error("Error during async cleanup for task {}", taskId.get(), err); + } + }); } LOGGER.debug("Returning: {}", kind); @@ -530,29 +587,44 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte @Override public Flow.Publisher onMessageSendStream( MessageSendParams params, ServerCallContext context) throws A2AError { - LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}; backgroundTasks: {}", - params.message().taskId(), params.message().contextId(), runningAgents.size(), backgroundTasks.size()); + LOGGER.debug("onMessageSendStream START - task: {}; context: {}; runningAgents: {}", + params.message().taskId(), params.message().contextId(), runningAgents.size()); + + // Build MessageSendSetup which creates RequestContext with real taskId (auto-generated if needed) MessageSendSetup mss = initMessageSend(params, context); - @Nullable String initialTaskId = mss.requestContext.getTaskId(); - // For streaming, taskId can be null initially (will be set when Task event arrives) - // Use a temporary ID for queue creation if needed - String queueTaskId = initialTaskId != null ? initialTaskId : "temp-" + java.util.UUID.randomUUID(); + // Use the taskId from RequestContext for queue management (no temp ID needed!) 
+ // RequestContext.build() guarantees taskId is non-null via checkOrGenerateTaskId() + String queueTaskId = java.util.Objects.requireNonNull( + mss.requestContext.getTaskId(), "TaskId must be non-null after RequestContext.build()"); + final AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); - AtomicReference<@NonNull String> taskId = new AtomicReference<>(queueTaskId); - @SuppressWarnings("NullAway") - EventQueue queue = queueManager.createOrTap(taskId.get()); + // Create queue with real taskId (no tempId parameter needed) + EventQueue queue = queueManager.createOrTap(queueTaskId); LOGGER.debug("Created/tapped queue for task {}: {}", taskId.get(), queue); - ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor); + + // Store push notification config SYNCHRONOUSLY for new tasks before agent starts + // This ensures config is available when MainEventBusProcessor sends push notifications + // For existing tasks, config is stored in initMessageSend() + if (mss.task() == null && shouldAddPushInfo(params)) { + // Satisfy Nullaway + Objects.requireNonNull(taskId.get(), "taskId was null"); + LOGGER.debug("Storing push notification config for new streaming task {} EARLY (original taskId from params: {})", + taskId.get(), params.message().taskId()); + pushConfigStore.setInfo(taskId.get(), params.configuration().pushNotificationConfig()); + } + + ResultAggregator resultAggregator = new ResultAggregator(mss.taskManager, null, executor, eventConsumerExecutor); EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(queueTaskId, mss.requestContext, queue); // Move consumer creation and callback registration outside try block - // so consumer is available for background consumption on client disconnect EventConsumer consumer = new EventConsumer(queue); producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback()); - AtomicBoolean backgroundConsumeStarted = new AtomicBoolean(false); + // Store 
cancel callback in context for closeHandler to access + // When client disconnects, closeHandler can call this to stop EventConsumer polling loop + context.setEventConsumerCancelCallback(consumer::cancel); try { Flow.Publisher results = resultAggregator.consumeAndEmit(consumer); @@ -562,36 +634,13 @@ public Flow.Publisher onMessageSendStream( processor(createTubeConfig(), results, ((errorConsumer, item) -> { Event event = item.getEvent(); if (event instanceof Task createdTask) { - if (!Objects.equals(taskId.get(), createdTask.id())) { - errorConsumer.accept(new InternalError("Task ID mismatch in agent response")); - } - - // TODO the Python implementation no longer has the following block but removing it causes - // failures here - try { - queueManager.add(createdTask.id(), queue); - taskId.set(createdTask.id()); - } catch (TaskQueueExistsException e) { - // TODO Log - } - if (pushConfigStore != null && - params.configuration() != null && - params.configuration().pushNotificationConfig() != null) { - - pushConfigStore.setInfo( - createdTask.id(), - params.configuration().pushNotificationConfig()); - } - - } - String currentTaskId = taskId.get(); - if (pushSender != null && currentTaskId != null) { - EventKind latest = resultAggregator.getCurrentResult(); - if (latest instanceof Task latestTask) { - pushSender.sendNotification(latestTask); + // Verify task ID matches (should always match now - agent uses context.getTaskId()) + String currentId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); + if (!currentId.equals(createdTask.id())) { + errorConsumer.accept(new InternalError("Task ID mismatch: expected " + currentId + + " but got " + createdTask.id())); } } - return true; })); @@ -600,7 +649,8 @@ public Flow.Publisher onMessageSendStream( Flow.Publisher finalPublisher = convertingProcessor(eventPublisher, event -> (StreamingEventKind) event); - // Wrap publisher to detect client disconnect and continue background consumption + // Wrap publisher to 
detect client disconnect and immediately close ChildQueue + // This prevents ChildQueue backpressure from blocking MainEventBusProcessor return subscriber -> { String currentTaskId = taskId.get(); LOGGER.debug("Creating subscription wrapper for task {}", currentTaskId); @@ -621,8 +671,10 @@ public void request(long n) { @Override public void cancel() { - LOGGER.debug("Client cancelled subscription for task {}, starting background consumption", taskId.get()); - startBackgroundConsumption(); + LOGGER.debug("Client cancelled subscription for task {}, closing ChildQueue immediately", taskId.get()); + // Close ChildQueue immediately to prevent backpressure + // (clears queue and releases semaphore permits) + queue.close(true); // immediate=true subscription.cancel(); } }); @@ -647,8 +699,8 @@ public void onComplete() { subscriber.onComplete(); } catch (IllegalStateException e) { // Client already disconnected and response closed - this is expected - // for streaming responses where client disconnect triggers background - // consumption. Log and ignore. + // for streaming responses where client disconnect closes ChildQueue. + // Log and ignore. 
if (e.getMessage() != null && e.getMessage().contains("Response has already been written")) { LOGGER.debug("Client disconnected before onComplete, response already closed for task {}", taskId.get()); } else { @@ -656,36 +708,26 @@ public void onComplete() { } } } - - private void startBackgroundConsumption() { - if (backgroundConsumeStarted.compareAndSet(false, true)) { - LOGGER.debug("Starting background consumption for task {}", taskId.get()); - // Client disconnected: continue consuming and persisting events in background - CompletableFuture bgTask = CompletableFuture.runAsync(() -> { - try { - LOGGER.debug("Background consumption thread started for task {}", taskId.get()); - resultAggregator.consumeAll(consumer); - LOGGER.debug("Background consumption completed for task {}", taskId.get()); - } catch (Exception e) { - LOGGER.error("Error during background consumption for task {}", taskId.get(), e); - } - }, executor); - trackBackgroundTask(bgTask); - } else { - LOGGER.debug("Background consumption already started for task {}", taskId.get()); - } - } }); }; } finally { - LOGGER.debug("onMessageSendStream FINALLY - task: {}; runningAgents: {}; backgroundTasks: {}", - taskId.get(), runningAgents.size(), backgroundTasks.size()); - - // Remove agent from map immediately to prevent accumulation - CompletableFuture agentFuture = runningAgents.remove(taskId.get()); - LOGGER.debug("Removed agent for task {} from runningAgents in finally block, size after: {}", taskId.get(), runningAgents.size()); - - trackBackgroundTask(cleanupProducer(agentFuture, null, Objects.requireNonNull(taskId.get()), queue, true)); + // Needed to satisfy Nullaway + String idOfTask = taskId.get(); + if (idOfTask != null) { + LOGGER.debug("onMessageSendStream FINALLY - task: {}; runningAgents: {}", + idOfTask, runningAgents.size()); + + // Remove agent from map immediately to prevent accumulation + CompletableFuture agentFuture = runningAgents.remove(idOfTask); + LOGGER.debug("Removed agent for 
task {} from runningAgents in finally block, size after: {}", taskId.get(), runningAgents.size()); + + cleanupProducer(agentFuture, null, idOfTask, queue, true) + .whenComplete((res, err) -> { + if (err != null) { + LOGGER.error("Error during async cleanup for streaming task {}", taskId.get(), err); + } + }); + } } } @@ -746,7 +788,7 @@ public Flow.Publisher onResubscribeToTask( } TaskManager taskManager = new TaskManager(task.id(), task.contextId(), taskStore, null); - ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor); + ResultAggregator resultAggregator = new ResultAggregator(taskManager, null, executor, eventConsumerExecutor); EventQueue queue = queueManager.tap(task.id()); LOGGER.debug("onResubscribeToTask - tapped queue: {}", queue != null ? System.identityHashCode(queue) : "null"); @@ -761,6 +803,13 @@ public Flow.Publisher onResubscribeToTask( queue = queueManager.createOrTap(task.id()); } + // Per A2A Protocol Spec 3.1.6 (Subscribe to Task): + // "The operation MUST return a Task object as the first event in the stream, + // representing the current state of the task at the time of subscription." + // Enqueue the current task state directly to this ChildQueue only (already persisted, no need for MainEventBus) + queue.enqueueEventLocalOnly(task); + LOGGER.debug("onResubscribeToTask - enqueued current task state as first event for taskId: {}", params.id()); + EventConsumer consumer = new EventConsumer(queue); Flow.Publisher results = resultAggregator.consumeAndEmit(consumer); LOGGER.debug("onResubscribeToTask - returning publisher for taskId: {}", params.id()); @@ -819,8 +868,7 @@ public void run() { LOGGER.debug("Agent execution starting for task {}", taskId); agentExecutor.execute(requestContext, queue); LOGGER.debug("Agent execution completed for task {}", taskId); - // No longer wait for queue poller to start - the consumer (which is guaranteed - // to be running on the Vert.x worker thread) will handle queue lifecycle. 
+ // The consumer (running on the Vert.x worker thread) handles queue lifecycle. // This avoids blocking agent-executor threads waiting for worker threads. } }; @@ -833,8 +881,8 @@ public void run() { // Don't close queue here - let the consumer handle it via error callback // This ensures the consumer (which may not have started polling yet) gets the error } - // Queue lifecycle is now managed entirely by EventConsumer.consumeAll() - // which closes the queue on final events. No need to close here. + // Queue lifecycle is managed by EventConsumer.consumeAll() + // which closes the queue on final events. logThreadStats("AGENT COMPLETE END"); runnable.invokeDoneCallbacks(); }); @@ -843,47 +891,6 @@ public void run() { return runnable; } - private void trackBackgroundTask(CompletableFuture task) { - backgroundTasks.add(task); - LOGGER.debug("Tracking background task (total: {}): {}", backgroundTasks.size(), task); - - task.whenComplete((result, throwable) -> { - try { - if (throwable != null) { - // Unwrap CompletionException to check for CancellationException - Throwable cause = throwable; - if (throwable instanceof java.util.concurrent.CompletionException && throwable.getCause() != null) { - cause = throwable.getCause(); - } - - if (cause instanceof java.util.concurrent.CancellationException) { - LOGGER.debug("Background task cancelled: {}", task); - } else { - LOGGER.error("Background task failed", throwable); - } - } - } finally { - backgroundTasks.remove(task); - LOGGER.debug("Removed background task (remaining: {}): {}", backgroundTasks.size(), task); - } - }); - } - - /** - * Wait for all background tasks to complete. - * Useful for testing to ensure cleanup completes before assertions. 
- * - * @return CompletableFuture that completes when all background tasks finish - */ - public CompletableFuture waitForBackgroundTasks() { - CompletableFuture[] tasks = backgroundTasks.toArray(new CompletableFuture[0]); - if (tasks.length == 0) { - return CompletableFuture.completedFuture(null); - } - LOGGER.debug("Waiting for {} background tasks to complete", tasks.length); - return CompletableFuture.allOf(tasks); - } - private CompletableFuture cleanupProducer(@Nullable CompletableFuture agentFuture, @Nullable CompletableFuture consumptionFuture, String taskId, EventQueue queue, boolean isStreaming) { LOGGER.debug("Starting cleanup for task {} (streaming={})", taskId, isStreaming); logThreadStats("CLEANUP START"); @@ -908,14 +915,20 @@ private CompletableFuture cleanupProducer(@Nullable CompletableFuture cleanupProducer(@Nullable CompletableFuture + * Enable independently with: {@code logging.level.io.a2a.server.diagnostics.ThreadStats=DEBUG} + *

    */ @SuppressWarnings("unused") // Used for debugging private void logThreadStats(String label) { // Early return if debug logging is not enabled to avoid overhead - if (!LOGGER.isDebugEnabled()) { + if (!THREAD_STATS_LOGGER.isDebugEnabled()) { return; } @@ -982,28 +1029,57 @@ private void logThreadStats(String label) { } int activeThreads = rootGroup.activeCount(); - LOGGER.debug("=== THREAD STATS: {} ===", label); - LOGGER.debug("Active threads: {}", activeThreads); - LOGGER.debug("Running agents: {}", runningAgents.size()); - LOGGER.debug("Background tasks: {}", backgroundTasks.size()); - LOGGER.debug("Queue manager active queues: {}", queueManager.getClass().getSimpleName()); + // Count specific thread types + Thread[] threads = new Thread[activeThreads * 2]; + int count = rootGroup.enumerate(threads); + int eventConsumerThreads = 0; + int agentExecutorThreads = 0; + for (int i = 0; i < count; i++) { + if (threads[i] != null) { + String name = threads[i].getName(); + if (name.startsWith("a2a-event-consumer-")) { + eventConsumerThreads++; + } else if (name.startsWith("a2a-agent-executor-")) { + agentExecutorThreads++; + } + } + } + + THREAD_STATS_LOGGER.debug("=== THREAD STATS: {} ===", label); + THREAD_STATS_LOGGER.debug("Total active threads: {}", activeThreads); + THREAD_STATS_LOGGER.debug("EventConsumer threads: {}", eventConsumerThreads); + THREAD_STATS_LOGGER.debug("AgentExecutor threads: {}", agentExecutorThreads); + THREAD_STATS_LOGGER.debug("Running agents: {}", runningAgents.size()); + THREAD_STATS_LOGGER.debug("Queue manager active queues: {}", queueManager.getClass().getSimpleName()); // List running agents if (!runningAgents.isEmpty()) { - LOGGER.debug("Running agent tasks:"); + THREAD_STATS_LOGGER.debug("Running agent tasks:"); runningAgents.forEach((taskId, future) -> - LOGGER.debug(" - Task {}: {}", taskId, future.isDone() ? "DONE" : "RUNNING") + THREAD_STATS_LOGGER.debug(" - Task {}: {}", taskId, future.isDone() ? 
"DONE" : "RUNNING") ); } - // List background tasks - if (!backgroundTasks.isEmpty()) { - LOGGER.debug("Background tasks:"); - backgroundTasks.forEach(task -> - LOGGER.debug(" - {}: {}", task, task.isDone() ? "DONE" : "RUNNING") - ); + THREAD_STATS_LOGGER.debug("=== END THREAD STATS ==="); + } + + /** + * Check if an event represents a final task state. + * + * @param eventKind the event to check + * @return true if the event represents a final state (COMPLETED, FAILED, CANCELED, REJECTED, UNKNOWN) + */ + private boolean isFinalEvent(EventKind eventKind) { + if (!(eventKind instanceof Event event)) { + return false; + } + if (event instanceof Task task) { + return task.status() != null && task.status().state() != null + && task.status().state().isFinal(); + } else if (event instanceof TaskStatusUpdateEvent statusUpdate) { + return statusUpdate.isFinal(); } - LOGGER.debug("=== END THREAD STATS ==="); + return false; } private record MessageSendSetup(TaskManager taskManager, @Nullable Task task, RequestContext requestContext) {} diff --git a/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java b/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java index 1e3ae1206..15f94d7e8 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/InMemoryTaskStore.java @@ -32,8 +32,9 @@ public class InMemoryTaskStore implements TaskStore, TaskStateProvider { private final ConcurrentMap tasks = new ConcurrentHashMap<>(); @Override - public void save(Task task) { + public void save(Task task, boolean isReplicated) { tasks.put(task.id(), task); + // InMemoryTaskStore doesn't fire TaskFinalizedEvent, so isReplicated is unused here } @Override diff --git a/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java b/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java index 95684e199..506b3f3b6 100644 --- 
a/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java +++ b/server-common/src/main/java/io/a2a/server/tasks/ResultAggregator.java @@ -14,9 +14,9 @@ import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueueItem; import io.a2a.spec.A2AError; -import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; import io.a2a.spec.EventKind; +import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskState; @@ -31,12 +31,14 @@ public class ResultAggregator { private final TaskManager taskManager; private final Executor executor; + private final Executor eventConsumerExecutor; private volatile @Nullable Message message; - public ResultAggregator(TaskManager taskManager, @Nullable Message message, Executor executor) { + public ResultAggregator(TaskManager taskManager, @Nullable Message message, Executor executor, Executor eventConsumerExecutor) { this.taskManager = taskManager; this.message = message; this.executor = executor; + this.eventConsumerExecutor = eventConsumerExecutor; } public @Nullable EventKind getCurrentResult() { @@ -49,20 +51,23 @@ public ResultAggregator(TaskManager taskManager, @Nullable Message message, Exec public Flow.Publisher consumeAndEmit(EventConsumer consumer) { Flow.Publisher allItems = consumer.consumeAll(); - // Process items conditionally - only save non-replicated events to database - return processor(createTubeConfig(), allItems, (errorConsumer, item) -> { - // Only process non-replicated events to avoid duplicate database writes - if (!item.isReplicated()) { - try { - callTaskManagerProcess(item.getEvent()); - } catch (A2AServerException e) { - errorConsumer.accept(e); - return false; - } - } - // Continue processing and emit (both replicated and non-replicated) + // Just stream events - no persistence needed + // TaskStore update moved to MainEventBusProcessor + Flow.Publisher processed = processor(createTubeConfig(), allItems, (errorConsumer, 
item) -> { + // Continue processing and emit all events return true; }); + + // Wrap the publisher to ensure subscription happens on eventConsumerExecutor + // This prevents EventConsumer polling loop from running on AgentExecutor threads + // which caused thread accumulation when those threads didn't timeout + return new Flow.Publisher() { + @Override + public void subscribe(Flow.Subscriber subscriber) { + // Submit subscription to eventConsumerExecutor to isolate polling work + eventConsumerExecutor.execute(() -> processed.subscribe(subscriber)); + } + }; } public EventKind consumeAll(EventConsumer consumer) throws A2AError { @@ -81,15 +86,7 @@ public EventKind consumeAll(EventConsumer consumer) throws A2AError { return false; } } - // Only process non-replicated events to avoid duplicate database writes - if (!item.isReplicated()) { - try { - callTaskManagerProcess(event); - } catch (A2AServerException e) { - error.set(e); - return false; - } - } + // TaskStore update moved to MainEventBusProcessor return true; }, error::set); @@ -113,18 +110,24 @@ public EventKind consumeAll(EventConsumer consumer) throws A2AError { public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, boolean blocking) throws A2AError { Flow.Publisher allItems = consumer.consumeAll(); AtomicReference message = new AtomicReference<>(); + AtomicReference capturedTask = new AtomicReference<>(); // Capture Task events AtomicBoolean interrupted = new AtomicBoolean(false); AtomicReference errorRef = new AtomicReference<>(); CompletableFuture completionFuture = new CompletableFuture<>(); // Separate future for tracking background consumption completion CompletableFuture consumptionCompletionFuture = new CompletableFuture<>(); + // Latch to ensure EventConsumer starts polling before we wait on completionFuture + java.util.concurrent.CountDownLatch pollingStarted = new java.util.concurrent.CountDownLatch(1); // CRITICAL: The subscription itself must run on a background thread 
to avoid blocking // the Vert.x worker thread. EventConsumer.consumeAll() starts a polling loop that // blocks in dequeueEventItem(), so we must subscribe from a background thread. - // Use the @Internal executor (not ForkJoinPool.commonPool) to avoid saturation - // during concurrent request bursts. + // Use the dedicated @EventConsumerExecutor (cached thread pool) which creates threads + // on demand for I/O-bound polling. Using the @Internal executor caused deadlock when + // pool exhausted (100+ concurrent queues but maxPoolSize=50). CompletableFuture.runAsync(() -> { + // Signal that polling is about to start + pollingStarted.countDown(); consumer( createTubeConfig(), allItems, @@ -146,25 +149,30 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, return false; } - // Process event through TaskManager - only for non-replicated events - if (!item.isReplicated()) { - try { - callTaskManagerProcess(event); - } catch (A2AServerException e) { - errorRef.set(e); - completionFuture.completeExceptionally(e); - return false; + // Capture Task events (especially for new tasks where taskManager.getTask() would return null) + // We capture the LATEST task to ensure we get the most up-to-date state + if (event instanceof Task t) { + Task previousTask = capturedTask.get(); + capturedTask.set(t); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Captured Task event: id={}, state={} (previous: {})", + t.id(), t.status().state(), + previousTask != null ? 
previousTask.id() + "/" + previousTask.status().state() : "none"); } } + // TaskStore update moved to MainEventBusProcessor + // Determine interrupt behavior boolean shouldInterrupt = false; - boolean continueInBackground = false; boolean isFinalEvent = (event instanceof Task task && task.status().state().isFinal()) || (event instanceof TaskStatusUpdateEvent tsue && tsue.isFinal()); boolean isAuthRequired = (event instanceof Task task && task.status().state() == TaskState.AUTH_REQUIRED) || (event instanceof TaskStatusUpdateEvent tsue && tsue.status().state() == TaskState.AUTH_REQUIRED); + LOGGER.debug("ResultAggregator: Evaluating interrupt (blocking={}, isFinal={}, isAuth={}, eventType={})", + blocking, isFinalEvent, isAuthRequired, event.getClass().getSimpleName()); + // Always interrupt on auth_required, as it needs external action. if (isAuthRequired) { // auth-required is a special state: the message should be @@ -174,20 +182,19 @@ public EventTypeAndInterrupt consumeAndBreakOnInterrupt(EventConsumer consumer, // new request is expected in order for the agent to make progress, // so the agent should exit. shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (AUTH_REQUIRED)"); } else if (!blocking) { // For non-blocking calls, interrupt as soon as a task is available. 
shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (non-blocking)"); } else if (blocking) { // For blocking calls: Interrupt to free Vert.x thread, but continue in background // Python's async consumption doesn't block threads, but Java's does // So we interrupt to return quickly, then rely on background consumption - // DefaultRequestHandler will fetch the final state from TaskStore shouldInterrupt = true; - continueInBackground = true; + LOGGER.debug("ResultAggregator: Setting shouldInterrupt=true (blocking, isFinal={})", isFinalEvent); if (LOGGER.isDebugEnabled()) { LOGGER.debug("Blocking call for task {}: {} event, returning with background consumption", taskIdForLogging(), isFinalEvent ? "final" : "non-final"); @@ -195,14 +202,14 @@ else if (blocking) { } if (shouldInterrupt) { + LOGGER.debug("ResultAggregator: Interrupting consumption (setting interrupted=true)"); // Complete the future to unblock the main thread interrupted.set(true); completionFuture.complete(null); // For blocking calls, DON'T complete consumptionCompletionFuture here. // Let it complete naturally when subscription finishes (onComplete callback below). - // This ensures all events are processed and persisted to TaskStore before - // DefaultRequestHandler.cleanupProducer() proceeds with cleanup. + // This ensures all events are fully processed before cleanup. // // For non-blocking and auth-required calls, complete immediately to allow // cleanup to proceed while consumption continues in background. 
@@ -237,7 +244,16 @@ else if (blocking) { } } ); - }, executor); + }, eventConsumerExecutor); + + // Wait for EventConsumer to start polling before we wait for events + // This prevents race where agent enqueues events before EventConsumer starts + try { + pollingStarted.await(5, java.util.concurrent.TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new io.a2a.spec.InternalError("Interrupted waiting for EventConsumer to start"); + } // Wait for completion or interruption try { @@ -261,28 +277,30 @@ else if (blocking) { Utils.rethrow(error); } - EventKind eventType; - Message msg = message.get(); - if (msg != null) { - eventType = msg; - } else { - Task task = taskManager.getTask(); - if (task == null) { - throw new io.a2a.spec.InternalError("No task or message available after consuming events"); + // Return Message if captured, otherwise Task if captured, otherwise fetch from TaskStore + EventKind eventKind = message.get(); + if (eventKind == null) { + eventKind = capturedTask.get(); + if (LOGGER.isDebugEnabled() && eventKind instanceof Task t) { + LOGGER.debug("Returning capturedTask: id={}, state={}", t.id(), t.status().state()); } - eventType = task; + } + if (eventKind == null) { + eventKind = taskManager.getTask(); + if (LOGGER.isDebugEnabled() && eventKind instanceof Task t) { + LOGGER.debug("Returning task from TaskStore: id={}, state={}", t.id(), t.status().state()); + } + } + if (eventKind == null) { + throw new InternalError("Could not find a Task/Message for " + taskManager.getTaskId()); } return new EventTypeAndInterrupt( - eventType, + eventKind, interrupted.get(), consumptionCompletionFuture); } - private void callTaskManagerProcess(Event event) throws A2AServerException { - taskManager.process(event); - } - private String taskIdForLogging() { Task task = taskManager.getTask(); return task != null ? 
task.id() : "unknown"; diff --git a/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java b/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java index fd3696a60..948ec596c 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java +++ b/server-common/src/main/java/io/a2a/server/tasks/TaskManager.java @@ -12,7 +12,7 @@ import io.a2a.spec.A2AServerException; import io.a2a.spec.Event; -import io.a2a.spec.InvalidParamsError; +import io.a2a.spec.InternalError; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskArtifactUpdateEvent; @@ -59,12 +59,13 @@ public TaskManager(@Nullable String taskId, @Nullable String contextId, TaskStor return currentTask; } - Task saveTaskEvent(Task task) throws A2AServerException { + boolean saveTaskEvent(Task task, boolean isReplicated) throws A2AServerException { checkIdsAndUpdateIfNecessary(task.id(), task.contextId()); - return saveTask(task); + Task savedTask = saveTask(task, isReplicated); + return savedTask.status() != null && savedTask.status().state() != null && savedTask.status().state().isFinal(); } - Task saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException { + boolean saveTaskEvent(TaskStatusUpdateEvent event, boolean isReplicated) throws A2AServerException { checkIdsAndUpdateIfNecessary(event.taskId(), event.contextId()); Task task = ensureTask(event.taskId(), event.contextId()); @@ -86,10 +87,11 @@ Task saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException { } task = builder.build(); - return saveTask(task); + Task savedTask = saveTask(task, isReplicated); + return savedTask.status() != null && savedTask.status().state() != null && savedTask.status().state().isFinal(); } - Task saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException { + boolean saveTaskEvent(TaskArtifactUpdateEvent event, boolean isReplicated) throws A2AServerException { checkIdsAndUpdateIfNecessary(event.taskId(), event.contextId()); Task task = 
ensureTask(event.taskId(), event.contextId()); // taskId is guaranteed to be non-null after checkIdsAndUpdateIfNecessary @@ -98,18 +100,20 @@ Task saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException { throw new IllegalStateException("taskId should not be null after checkIdsAndUpdateIfNecessary"); } task = appendArtifactToTask(task, event, nonNullTaskId); - return saveTask(task); + Task savedTask = saveTask(task, isReplicated); + return savedTask.status() != null && savedTask.status().state() != null && savedTask.status().state().isFinal(); } - public Event process(Event event) throws A2AServerException { + public boolean process(Event event, boolean isReplicated) throws A2AServerException { + boolean isFinal = false; if (event instanceof Task task) { - saveTaskEvent(task); + isFinal = saveTaskEvent(task, isReplicated); } else if (event instanceof TaskStatusUpdateEvent taskStatusUpdateEvent) { - saveTaskEvent(taskStatusUpdateEvent); + isFinal = saveTaskEvent(taskStatusUpdateEvent, isReplicated); } else if (event instanceof TaskArtifactUpdateEvent taskArtifactUpdateEvent) { - saveTaskEvent(taskArtifactUpdateEvent); + isFinal = saveTaskEvent(taskArtifactUpdateEvent, isReplicated); } - return event; + return isFinal; } public Task updateWithMessage(Message message, Task task) { @@ -125,7 +129,7 @@ public Task updateWithMessage(Message message, Task task) { .status(status) .history(history) .build(); - saveTask(task); + saveTask(task, false); // Local operation, not replicated return task; } @@ -133,7 +137,7 @@ private void checkIdsAndUpdateIfNecessary(String eventTaskId, String eventContex if (taskId != null && !eventTaskId.equals(taskId)) { throw new A2AServerException( "Invalid task id", - new InvalidParamsError(String.format("Task in event doesn't match TaskManager "))); + new InternalError(String.format("Task event has taskId %s but TaskManager has %s", eventTaskId, taskId))); } if (taskId == null) { taskId = eventTaskId; @@ -155,7 +159,7 @@ 
private Task ensureTask(String eventTaskId, String eventContextId) { } if (task == null) { task = createTask(eventTaskId, eventContextId); - saveTask(task); + saveTask(task, false); // Local operation, not replicated } return task; } @@ -170,8 +174,8 @@ private Task createTask(String taskId, String contextId) { .build(); } - private Task saveTask(Task task) { - taskStore.save(task); + private Task saveTask(Task task, boolean isReplicated) { + taskStore.save(task, isReplicated); if (taskId == null) { taskId = task.id(); contextId = task.contextId(); diff --git a/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java b/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java index 18707fba2..3df903f77 100644 --- a/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java +++ b/server-common/src/main/java/io/a2a/server/tasks/TaskStore.java @@ -100,8 +100,11 @@ public interface TaskStore { * Saves or updates a task. * * @param task the task to save + * @param isReplicated true if this task update came from a replicated event, + * false if it originated locally. Used to prevent feedback loops + * in replicated scenarios (e.g., don't fire TaskFinalizedEvent for replicated updates) */ - void save(Task task); + void save(Task task, boolean isReplicated); /** * Retrieves a task by its ID. 
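The `TaskStore.save(Task, boolean)` signature change above threads an `isReplicated` flag from `TaskManager` down to the store so that replicated updates are persisted without re-firing finalization side effects. A minimal sketch of that contract, using stand-in `Task`/`State` types and an illustrative `finalizedHooksFired` list rather than the real `io.a2a.spec` classes and `TaskFinalizedEvent` machinery:

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

class ReplicationAwareStoreSketch {

    // Stand-in for io.a2a.spec.TaskState; only the finality flag matters here.
    enum State {
        WORKING(false), COMPLETED(true), FAILED(true), CANCELED(true);
        private final boolean fin;
        State(boolean fin) { this.fin = fin; }
        boolean isFinal() { return fin; }
    }

    // Stand-in for io.a2a.spec.Task.
    record Task(String id, State state) {}

    // Sketch of the save(task, isReplicated) contract: every update is
    // persisted, but only locally originated final updates may fire the
    // "task finalized" hook, so an instance never re-broadcasts an update
    // it just received from another instance (no Kafka feedback loop).
    static class StoreSketch {
        private final Map<String, Task> tasks = new ConcurrentHashMap<>();
        final List<String> finalizedHooksFired = new CopyOnWriteArrayList<>();

        void save(Task task, boolean isReplicated) {
            tasks.put(task.id(), task); // always persist, local or replicated
            if (!isReplicated && task.state().isFinal()) {
                finalizedHooksFired.add(task.id()); // local final update only
            }
        }

        // Mirrors the TaskStateProvider-style finality query used for
        // queue-lifecycle decisions (keep queue open vs. clean up).
        boolean isTaskFinalized(String taskId) {
            Task t = tasks.get(taskId);
            return t != null && t.state().isFinal();
        }
    }
}
```

With this split, a COMPLETED update that arrives via replication makes `isTaskFinalized` return true (so queue cleanup can proceed) without firing the hook, while a local final save does both.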
diff --git a/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java b/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java index e26dd55fb..eee254ba3 100644 --- a/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java +++ b/server-common/src/main/java/io/a2a/server/util/async/AsyncExecutorProducer.java @@ -1,8 +1,8 @@ package io.a2a.server.util.async; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ThreadFactory; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -26,6 +26,7 @@ public class AsyncExecutorProducer { private static final String A2A_EXECUTOR_CORE_POOL_SIZE = "a2a.executor.core-pool-size"; private static final String A2A_EXECUTOR_MAX_POOL_SIZE = "a2a.executor.max-pool-size"; private static final String A2A_EXECUTOR_KEEP_ALIVE_SECONDS = "a2a.executor.keep-alive-seconds"; + private static final String A2A_EXECUTOR_QUEUE_CAPACITY = "a2a.executor.queue-capacity"; @Inject A2AConfigProvider configProvider; @@ -57,6 +58,16 @@ public class AsyncExecutorProducer { */ long keepAliveSeconds; + /** + * Queue capacity for pending tasks. + *

    + * Property: {@code a2a.executor.queue-capacity}
    + * Default: 100
    + * Note: Must be bounded to allow pool growth to maxPoolSize. + * When queue is full, new threads are created up to maxPoolSize. + */ + int queueCapacity; + private @Nullable ExecutorService executor; @PostConstruct @@ -64,18 +75,34 @@ public void init() { corePoolSize = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_CORE_POOL_SIZE)); maxPoolSize = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_MAX_POOL_SIZE)); keepAliveSeconds = Long.parseLong(configProvider.getValue(A2A_EXECUTOR_KEEP_ALIVE_SECONDS)); - - LOGGER.info("Initializing async executor: corePoolSize={}, maxPoolSize={}, keepAliveSeconds={}", - corePoolSize, maxPoolSize, keepAliveSeconds); - - executor = new ThreadPoolExecutor( + queueCapacity = Integer.parseInt(configProvider.getValue(A2A_EXECUTOR_QUEUE_CAPACITY)); + + LOGGER.info("Initializing async executor: corePoolSize={}, maxPoolSize={}, keepAliveSeconds={}, queueCapacity={}", + corePoolSize, maxPoolSize, keepAliveSeconds, queueCapacity); + + // CRITICAL: Use ArrayBlockingQueue (bounded) instead of LinkedBlockingQueue (unbounded). + // With unbounded queue, ThreadPoolExecutor NEVER grows beyond corePoolSize because the + // queue never fills. This causes executor pool exhaustion during concurrent requests when + // EventConsumer polling threads hold all core threads and agent tasks queue indefinitely. + // Bounded queue enables pool growth: when queue is full, new threads are created up to + // maxPoolSize, preventing agent execution starvation. + ThreadPoolExecutor tpe = new ThreadPoolExecutor( corePoolSize, maxPoolSize, keepAliveSeconds, TimeUnit.SECONDS, - new LinkedBlockingQueue<>(), + new ArrayBlockingQueue<>(queueCapacity), new A2AThreadFactory() ); + + // CRITICAL: Allow core threads to timeout after keepAliveSeconds when idle. + // By default, ThreadPoolExecutor only times out threads above corePoolSize. + // Without this, core threads accumulate during testing and never clean up. 
+ // This is essential for streaming scenarios where many short-lived tasks create threads + // for agent execution and cleanup callbacks, but those threads remain idle afterward. + tpe.allowCoreThreadTimeOut(true); + + executor = tpe; } @PreDestroy @@ -106,6 +133,22 @@ public Executor produce() { return executor; } + /** + * Log current executor pool statistics for diagnostics. + * Useful for debugging pool exhaustion or sizing issues. + */ + public void logPoolStats() { + if (executor instanceof ThreadPoolExecutor tpe) { + LOGGER.info("Executor pool stats: active={}/{}, queued={}/{}, completed={}, total={}", + tpe.getActiveCount(), + tpe.getPoolSize(), + tpe.getQueue().size(), + queueCapacity, + tpe.getCompletedTaskCount(), + tpe.getTaskCount()); + } + } + private static class A2AThreadFactory implements ThreadFactory { private final AtomicInteger threadNumber = new AtomicInteger(1); private final String namePrefix = "a2a-agent-executor-"; diff --git a/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java b/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java new file mode 100644 index 000000000..24ff7f5d1 --- /dev/null +++ b/server-common/src/main/java/io/a2a/server/util/async/EventConsumerExecutorProducer.java @@ -0,0 +1,93 @@ +package io.a2a.server.util.async; + +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Produces; +import jakarta.inject.Qualifier; + +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import 
java.util.concurrent.atomic.AtomicInteger; + +import static java.lang.annotation.ElementType.*; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * Produces a dedicated executor for EventConsumer polling threads. + *

 + * <p>
 + * CRITICAL: EventConsumer polling must use a separate executor from AgentExecutor because:
 + * <ul>
 + *   <li>EventConsumer threads are I/O-bound (blocking on queue.poll()), not CPU-bound</li>
 + *   <li>One EventConsumer thread is needed per active queue (can be 100+ concurrent)</li>
 + *   <li>Threads are mostly idle, waiting for events</li>
 + *   <li>Using the same bounded pool as AgentExecutor causes deadlock when the pool is exhausted</li>
 + * </ul>
 + * <p>
 + * Uses a cached thread pool (unbounded) with automatic thread reclamation:
 + * <ul>
 + *   <li>Creates threads on demand as EventConsumers start</li>
 + *   <li>Idle threads are automatically terminated after 10 seconds</li>
 + *   <li>No queue saturation since threads are created as needed</li>
 + * </ul>
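The cached-pool configuration described in the Javadoc above (zero core threads, unbounded max, `SynchronousQueue` handoff, short idle timeout) can be sketched standalone. `CachedPoolDemo` below is an illustration of that configuration, not the producer class from this patch:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Standalone demo of the "cached pool" shape used for I/O-bound pollers:
 * 0 core threads, unbounded max, direct handoff via SynchronousQueue.
 * Because a SynchronousQueue never buffers tasks, each submitted blocking
 * task gets its own thread, and idle threads die after the keep-alive.
 */
public class CachedPoolDemo {

    static ThreadPoolExecutor newEventConsumerStylePool() {
        return new ThreadPoolExecutor(
                0,                         // no core threads
                Integer.MAX_VALUE,         // grow as far as demand requires
                10, TimeUnit.SECONDS,      // reclaim idle threads quickly
                new SynchronousQueue<>()); // direct handoff, no queueing
    }

    public static void main(String[] args) throws Exception {
        ThreadPoolExecutor pool = newEventConsumerStylePool();
        CountDownLatch release = new CountDownLatch(1);
        // Submit three blocking "poll loops"; each offer() to the
        // SynchronousQueue fails (no idle taker), so a new thread is created.
        for (int i = 0; i < 3; i++) {
            pool.execute(() -> {
                try { release.await(); } catch (InterruptedException ignored) { }
            });
        }
        System.out.println("pool size: " + pool.getPoolSize()); // prints "pool size: 3"
        release.countDown();
        pool.shutdown();
        pool.awaitTermination(5, TimeUnit.SECONDS);
    }
}
```

This is also why the AgentExecutor pool uses the opposite trade-off: a bounded `ArrayBlockingQueue` so the pool can grow toward `maxPoolSize` once the queue fills, rather than direct handoff.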
    + */ +@ApplicationScoped +public class EventConsumerExecutorProducer { + private static final Logger LOGGER = LoggerFactory.getLogger(EventConsumerExecutorProducer.class); + + /** + * Qualifier annotation for EventConsumer executor injection. + */ + @Retention(RUNTIME) + @Target({METHOD, FIELD, PARAMETER, TYPE}) + @Qualifier + public @interface EventConsumerExecutor { + } + + /** + * Thread factory for EventConsumer threads. + */ + private static class EventConsumerThreadFactory implements ThreadFactory { + private final AtomicInteger threadNumber = new AtomicInteger(1); + + @Override + public Thread newThread(Runnable r) { + Thread thread = new Thread(r, "a2a-event-consumer-" + threadNumber.getAndIncrement()); + thread.setDaemon(true); + return thread; + } + } + + private @Nullable ExecutorService executor; + + @Produces + @EventConsumerExecutor + @ApplicationScoped + public Executor eventConsumerExecutor() { + // Cached thread pool with 10s idle timeout (reduced from default 60s): + // - Creates threads on demand as EventConsumers start + // - Reclaims idle threads after 10s to prevent accumulation during fast test execution + // - Perfect for I/O-bound EventConsumer polling which blocks on queue.poll() + // - 10s timeout balances thread reuse (production) vs cleanup (testing) + executor = new ThreadPoolExecutor( + 0, // corePoolSize - no core threads + Integer.MAX_VALUE, // maxPoolSize - unbounded + 10, TimeUnit.SECONDS, // keepAliveTime - 10s idle timeout + new SynchronousQueue<>(), // queue - same as cached pool + new EventConsumerThreadFactory() + ); + + LOGGER.info("Initialized EventConsumer executor: cached thread pool (unbounded, 10s idle timeout)"); + + return executor; + } +} diff --git a/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java b/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java new file mode 100644 index 000000000..737fbac23 --- /dev/null +++ 
b/server-common/src/main/java/io/a2a/server/util/sse/SseFormatter.java @@ -0,0 +1,136 @@ +package io.a2a.server.util.sse; + +import io.a2a.grpc.utils.JSONRPCUtils; +import io.a2a.jsonrpc.common.wrappers.A2AErrorResponse; +import io.a2a.jsonrpc.common.wrappers.A2AResponse; +import io.a2a.jsonrpc.common.wrappers.CancelTaskResponse; +import io.a2a.jsonrpc.common.wrappers.DeleteTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.GetExtendedAgentCardResponse; +import io.a2a.jsonrpc.common.wrappers.GetTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.GetTaskResponse; +import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigResponse; +import io.a2a.jsonrpc.common.wrappers.ListTasksResponse; +import io.a2a.jsonrpc.common.wrappers.SendMessageResponse; +import io.a2a.jsonrpc.common.wrappers.SendStreamingMessageResponse; +import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigResponse; + +/** + * Framework-agnostic utility for formatting A2A responses as Server-Sent Events (SSE). + *

 + * <p>
 + * Provides static methods to serialize A2A responses to JSON and format them as SSE events.
 + * This allows HTTP server frameworks (Vert.x, Jakarta/WildFly, etc.) to use their own
 + * reactive libraries for publisher mapping while sharing the serialization logic.
 + * <p>
 + * Example usage (Quarkus/Vert.x with Mutiny):
 + * <pre>{@code
 + * Flow.Publisher<SendStreamingMessageResponse> responses = handler.onMessageSendStream(request, context);
 + * AtomicLong eventId = new AtomicLong(0);
 + *
 + * Multi<String> sseEvents = Multi.createFrom().publisher(responses)
 + *     .map(response -> SseFormatter.formatResponseAsSSE(response, eventId.getAndIncrement()));
 + *
 + * sseEvents.subscribe().with(sseEvent -> httpResponse.write(Buffer.buffer(sseEvent)));
 + * }</pre>
 + * <p>
 + * Example usage (Jakarta/WildFly with custom reactive library):
 + * <pre>{@code
 + * Flow.Publisher<String> jsonStrings = restHandler.getJsonPublisher();
 + * AtomicLong eventId = new AtomicLong(0);
 + *
 + * Flow.Publisher<String> sseEvents = mapPublisher(jsonStrings,
 + *     json -> SseFormatter.formatJsonAsSSE(json, eventId.getAndIncrement()));
 + * }</pre>
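The SSE framing produced by this utility is simple enough to sketch independently. The helper below mirrors the `formatJsonAsSSE` wire format described in this file (a `data:` line, an `id:` line, then a blank line) but is a standalone illustration, not the library class:

```java
/**
 * Minimal sketch of the SSE framing used here: each event is the JSON
 * payload on a "data:" line, followed by an "id:" line and a blank line
 * that terminates the event. Illustrative only.
 */
public class SseFormatSketch {

    /** Wrap a pre-serialized JSON string as a single SSE event. */
    static String formatJsonAsSSE(String jsonString, long eventId) {
        return "data: " + jsonString + "\nid: " + eventId + "\n\n";
    }

    public static void main(String[] args) {
        String event = formatJsonAsSSE("{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":123}", 0);
        // prints:
        // data: {"jsonrpc":"2.0","result":{},"id":123}
        // id: 0
        System.out.print(event);
    }
}
```

The trailing blank line is what marks the end of one event for SSE clients, which is why the 50ms grace before `tube.complete()` matters: it gives the transport a chance to flush a fully framed final event before the stream closes.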
    + */ +public class SseFormatter { + + private SseFormatter() { + // Utility class - prevent instantiation + } + + /** + * Format an A2A response as an SSE event. + *

 +     * <p>
 +     * Serializes the response to JSON and formats as:
 +     * <pre>
 +     * data: {"jsonrpc":"2.0","result":{...},"id":123}
 +     * id: 0
 +     * </pre>
    + * + * @param response the A2A response to format + * @param eventId the SSE event ID + * @return SSE-formatted string (ready to write to HTTP response) + */ + public static String formatResponseAsSSE(A2AResponse response, long eventId) { + String jsonData = serializeResponse(response); + return "data: " + jsonData + "\nid: " + eventId + "\n\n"; + } + + /** + * Format a pre-serialized JSON string as an SSE event. + *

 +     * <p>
 +     * Wraps the JSON in SSE format as:
 +     * <pre>
 +     * data: {"jsonrpc":"2.0","result":{...},"id":123}
 +     * id: 0
 +     * </pre>
 +     * <p>

    + * Use this when you already have JSON strings (e.g., from REST transport) + * and just need to add SSE formatting. + * + * @param jsonString the JSON string to wrap + * @param eventId the SSE event ID + * @return SSE-formatted string (ready to write to HTTP response) + */ + public static String formatJsonAsSSE(String jsonString, long eventId) { + return "data: " + jsonString + "\nid: " + eventId + "\n\n"; + } + + /** + * Serialize an A2AResponse to JSON string. + */ + private static String serializeResponse(A2AResponse response) { + // For error responses, use standard JSON-RPC error format + if (response instanceof A2AErrorResponse error) { + return JSONRPCUtils.toJsonRPCErrorResponse(error.getId(), error.getError()); + } + if (response.getError() != null) { + return JSONRPCUtils.toJsonRPCErrorResponse(response.getId(), response.getError()); + } + + // Convert domain response to protobuf message and serialize + com.google.protobuf.MessageOrBuilder protoMessage = convertToProto(response); + return JSONRPCUtils.toJsonRPCResultResponse(response.getId(), protoMessage); + } + + /** + * Convert A2AResponse to protobuf message for serialization. 
+ */ + private static com.google.protobuf.MessageOrBuilder convertToProto(A2AResponse response) { + if (response instanceof GetTaskResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.task(r.getResult()); + } else if (response instanceof CancelTaskResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.task(r.getResult()); + } else if (response instanceof SendMessageResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.taskOrMessage(r.getResult()); + } else if (response instanceof ListTasksResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.listTasksResult(r.getResult()); + } else if (response instanceof SetTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.setTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof GetTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.getTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof ListTaskPushNotificationConfigResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.listTaskPushNotificationConfigResponse(r.getResult()); + } else if (response instanceof DeleteTaskPushNotificationConfigResponse) { + // DeleteTaskPushNotificationConfig has no result body, just return empty message + return com.google.protobuf.Empty.getDefaultInstance(); + } else if (response instanceof GetExtendedAgentCardResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.getExtendedCardResponse(r.getResult()); + } else if (response instanceof SendStreamingMessageResponse r) { + return io.a2a.grpc.utils.ProtoUtils.ToProto.taskOrMessageStream(r.getResult()); + } else { + throw new IllegalArgumentException("Unknown response type: " + response.getClass().getName()); + } + } +} diff --git a/server-common/src/main/java/io/a2a/server/util/sse/package-info.java b/server-common/src/main/java/io/a2a/server/util/sse/package-info.java new file mode 100644 index 000000000..7e668b632 --- /dev/null +++ 
b/server-common/src/main/java/io/a2a/server/util/sse/package-info.java @@ -0,0 +1,11 @@ +/** + * Server-Sent Events (SSE) formatting utilities for A2A streaming responses. + *

    + * Provides framework-agnostic conversion of {@code Flow.Publisher>} to + * {@code Flow.Publisher} with SSE formatting, enabling easy integration with + * any HTTP server framework (Vert.x, Jakarta Servlet, etc.). + */ +@NullMarked +package io.a2a.server.util.sse; + +import org.jspecify.annotations.NullMarked; diff --git a/server-common/src/main/resources/META-INF/a2a-defaults.properties b/server-common/src/main/resources/META-INF/a2a-defaults.properties index 280fd943b..719be9e7a 100644 --- a/server-common/src/main/resources/META-INF/a2a-defaults.properties +++ b/server-common/src/main/resources/META-INF/a2a-defaults.properties @@ -19,3 +19,7 @@ a2a.executor.max-pool-size=50 # Keep-alive time for idle threads (seconds) a2a.executor.keep-alive-seconds=60 + +# Queue capacity for pending tasks (must be bounded to enable pool growth) +# When queue is full, new threads are created up to max-pool-size +a2a.executor.queue-capacity=100 diff --git a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java index 4354f1639..146bfb10a 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventConsumerTest.java @@ -16,6 +16,8 @@ import java.util.concurrent.atomic.AtomicReference; import io.a2a.jsonrpc.common.json.JsonProcessingException; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.A2AError; import io.a2a.spec.A2AServerException; import io.a2a.spec.Artifact; @@ -27,14 +29,19 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class EventConsumerTest { + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + private static 
final String TASK_ID = "123"; // Must match MINIMAL_TASK id + private EventQueue eventQueue; private EventConsumer eventConsumer; - + private MainEventBus mainEventBus; + private MainEventBusProcessor mainEventBusProcessor; private static final String MINIMAL_TASK = """ { @@ -54,10 +61,59 @@ public class EventConsumerTest { @BeforeEach public void init() { - eventQueue = EventQueue.builder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); eventConsumer = new EventConsumer(eventQueue); } + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. 
+ * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } + } + @Test public void testConsumeOneTaskEvent() throws Exception { Task event = fromJson(MINIMAL_TASK, Task.class); @@ -92,7 +148,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { List events = List.of( fromJson(MINIMAL_TASK, Task.class), TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -100,7 +156,7 @@ public void testConsumeAllMultipleEvents() throws JsonProcessingException { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -128,7 +184,7 @@ public void testConsumeUntilMessage() throws Exception { List events = List.of( fromJson(MINIMAL_TASK, Task.class), TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -136,7 +192,7 @@ public void testConsumeUntilMessage() throws Exception { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + 
.taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -185,14 +241,14 @@ public void testConsumeMessageEvents() throws Exception { @Test public void testConsumeTaskInputRequired() { Task task = Task.builder() - .id("task-id") - .contextId("task-context") + .id(TASK_ID) + .contextId("session-xyz") .status(new TaskStatus(TaskState.INPUT_REQUIRED)) .build(); List events = List.of( task, TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -200,7 +256,7 @@ public void testConsumeTaskInputRequired() { .build()) .build(), TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -332,7 +388,9 @@ public void onComplete() { @Test public void testConsumeAllStopsOnQueueClosed() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Close the queue immediately @@ -378,12 +436,16 @@ public void onComplete() { @Test public void testConsumeAllHandlesQueueClosedException() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Add a message event (which will complete the stream) Event message = fromJson(MESSAGE_PAYLOAD, Message.class); - queue.enqueueEvent(message); + + // Use callback to wait for event processing + waitForEventProcessing(() -> queue.enqueueEvent(message)); // Close the queue before consuming queue.close(); @@ -428,11 +490,13 @@ public void onComplete() { @Test public void testConsumeAllTerminatesOnQueueClosedEvent() throws Exception { - EventQueue 
queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .mainEventBus(mainEventBus) + .build().tap(); EventConsumer consumer = new EventConsumer(queue); // Enqueue a QueueClosedEvent (poison pill) - QueueClosedEvent queueClosedEvent = new QueueClosedEvent("task-123"); + QueueClosedEvent queueClosedEvent = new QueueClosedEvent(TASK_ID); queue.enqueueEvent(queueClosedEvent); Flow.Publisher publisher = consumer.consumeAll(); @@ -477,8 +541,12 @@ public void onComplete() { } private void enqueueAndConsumeOneEvent(Event event) throws Exception { - eventQueue.enqueueEvent(event); + // Use callback to wait for event processing + waitForEventProcessing(() -> eventQueue.enqueueEvent(event)); + + // Event is now available, consume it directly Event result = eventConsumer.consumeOne(); + assertNotNull(result, "Event should be available"); assertSame(event, result); } diff --git a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java index a3dc7d916..2499a8173 100644 --- a/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueTest.java @@ -11,7 +11,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import io.a2a.server.tasks.InMemoryTaskStore; +import io.a2a.server.tasks.PushNotificationSender; import io.a2a.spec.A2AError; import io.a2a.spec.Artifact; import io.a2a.spec.Event; @@ -23,12 +27,17 @@ import io.a2a.spec.TaskStatus; import io.a2a.spec.TaskStatusUpdateEvent; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; public class EventQueueTest { private EventQueue eventQueue; + private MainEventBus mainEventBus; + private MainEventBusProcessor 
mainEventBusProcessor; + + private static final String TASK_ID = "123"; // Must match MINIMAL_TASK id private static final String MINIMAL_TASK = """ { @@ -46,38 +55,96 @@ public class EventQueueTest { } """; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; @BeforeEach public void init() { - eventQueue = EventQueue.builder().build(); + // Set up MainEventBus and processor for production-like test environment + InMemoryTaskStore taskStore = new InMemoryTaskStore(); + mainEventBus = new MainEventBus(); + InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + + eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .mainEventBus(mainEventBus) + .build().tap(); + } + + @AfterEach + public void cleanup() { + if (mainEventBusProcessor != null) { + mainEventBusProcessor.setCallback(null); // Clear any test callbacks + EventQueueUtil.stop(mainEventBusProcessor); + } + } + /** + * Helper to create a queue with MainEventBus configured (for tests that need event distribution). + */ + private EventQueue createQueueWithEventBus(String taskId) { + return EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(taskId) + .build(); + } + + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. 
+ * + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, io.a2a.spec.Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + mainEventBusProcessor.setCallback(null); + } } @Test public void testConstructorDefaultQueueSize() { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); assertEquals(EventQueue.DEFAULT_QUEUE_SIZE, queue.getQueueSize()); } @Test public void testConstructorCustomQueueSize() { int customSize = 500; - EventQueue queue = EventQueue.builder().queueSize(customSize).build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(customSize).build(); assertEquals(customSize, queue.getQueueSize()); } @Test public void testConstructorInvalidQueueSize() { // Test zero queue size - assertThrows(IllegalArgumentException.class, () -> EventQueue.builder().queueSize(0).build()); + assertThrows(IllegalArgumentException.class, () -> EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(0).build()); // Test negative queue size - assertThrows(IllegalArgumentException.class, () -> EventQueue.builder().queueSize(-10).build()); + assertThrows(IllegalArgumentException.class, () -> 
EventQueueUtil.getEventQueueBuilder(mainEventBus).queueSize(-10).build()); } @Test public void testTapCreatesChildQueue() { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue childQueue = parentQueue.tap(); assertNotNull(childQueue); @@ -87,7 +154,7 @@ public void testTapCreatesChildQueue() { @Test public void testTapOnChildQueueThrowsException() { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue childQueue = parentQueue.tap(); assertThrows(IllegalStateException.class, () -> childQueue.tap()); @@ -95,69 +162,74 @@ public void testTapOnChildQueueThrowsException() { @Test public void testEnqueueEventPropagagesToChildren() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - parentQueue.enqueueEvent(event); + mainQueue.enqueueEvent(event); - // Event should be available in both parent and child queues - Event parentEvent = parentQueue.dequeueEventItem(-1).getEvent(); - Event childEvent = childQueue.dequeueEventItem(-1).getEvent(); + // Event should be available in all child queues + // Note: MainEventBusProcessor runs async, so we use dequeueEventItem with timeout + Event child1Event = childQueue1.dequeueEventItem(5000).getEvent(); + Event child2Event = childQueue2.dequeueEventItem(5000).getEvent(); - assertSame(event, parentEvent); - assertSame(event, childEvent); + assertSame(event, child1Event); + assertSame(event, child2Event); } @Test public void testMultipleChildQueuesReceiveEvents() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue1 = 
parentQueue.tap(); - EventQueue childQueue2 = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); + EventQueue childQueue3 = mainQueue.tap(); Event event1 = fromJson(MINIMAL_TASK, Task.class); Event event2 = fromJson(MESSAGE_PAYLOAD, Message.class); - parentQueue.enqueueEvent(event1); - parentQueue.enqueueEvent(event2); + mainQueue.enqueueEvent(event1); + mainQueue.enqueueEvent(event2); - // All queues should receive both events - assertSame(event1, parentQueue.dequeueEventItem(-1).getEvent()); - assertSame(event2, parentQueue.dequeueEventItem(-1).getEvent()); + // All child queues should receive both events + // Note: Use timeout for async processing + assertSame(event1, childQueue1.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue1.dequeueEventItem(5000).getEvent()); - assertSame(event1, childQueue1.dequeueEventItem(-1).getEvent()); - assertSame(event2, childQueue1.dequeueEventItem(-1).getEvent()); + assertSame(event1, childQueue2.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue2.dequeueEventItem(5000).getEvent()); - assertSame(event1, childQueue2.dequeueEventItem(-1).getEvent()); - assertSame(event2, childQueue2.dequeueEventItem(-1).getEvent()); + assertSame(event1, childQueue3.dequeueEventItem(5000).getEvent()); + assertSame(event2, childQueue3.dequeueEventItem(5000).getEvent()); } @Test public void testChildQueueDequeueIndependently() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); - EventQueue childQueue1 = parentQueue.tap(); - EventQueue childQueue2 = parentQueue.tap(); + EventQueue mainQueue = createQueueWithEventBus(TASK_ID); + EventQueue childQueue1 = mainQueue.tap(); + EventQueue childQueue2 = mainQueue.tap(); + EventQueue childQueue3 = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - parentQueue.enqueueEvent(event); + mainQueue.enqueueEvent(event); - // 
Dequeue from child1 first - Event child1Event = childQueue1.dequeueEventItem(-1).getEvent(); + // Dequeue from child1 first (use timeout for async processing) + Event child1Event = childQueue1.dequeueEventItem(5000).getEvent(); assertSame(event, child1Event); // child2 should still have the event available - Event child2Event = childQueue2.dequeueEventItem(-1).getEvent(); + Event child2Event = childQueue2.dequeueEventItem(5000).getEvent(); assertSame(event, child2Event); - // Parent should still have the event available - Event parentEvent = parentQueue.dequeueEventItem(-1).getEvent(); - assertSame(event, parentEvent); + // child3 should still have the event available + Event child3Event = childQueue3.dequeueEventItem(5000).getEvent(); + assertSame(event, child3Event); } @Test public void testCloseImmediatePropagationToChildren() throws Exception { - EventQueue parentQueue = EventQueue.builder().build(); + EventQueue parentQueue = createQueueWithEventBus(TASK_ID); EventQueue childQueue = parentQueue.tap(); // Add events to both parent and child @@ -166,7 +238,7 @@ public void testCloseImmediatePropagationToChildren() throws Exception { assertFalse(childQueue.isClosed()); try { - assertNotNull(childQueue.dequeueEventItem(-1)); // Child has the event + assertNotNull(childQueue.dequeueEventItem(5000)); // Child has the event (use timeout) } catch (EventQueueClosedException e) { // This is fine if queue closed before dequeue } @@ -187,27 +259,37 @@ public void testCloseImmediatePropagationToChildren() throws Exception { @Test public void testEnqueueEventWhenClosed() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .build(); + EventQueue childQueue = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - queue.close(); // Close the queue first - assertTrue(queue.isClosed()); + childQueue.close(); // Close the child queue first (removes from 
children list) + assertTrue(childQueue.isClosed()); + + // Create a new child queue BEFORE enqueuing (ensures it's in children list for distribution) + EventQueue newChildQueue = mainQueue.tap(); // MainQueue accepts events even when closed (for replication support) // This ensures late-arriving replicated events can be enqueued to closed queues - queue.enqueueEvent(event); + // Note: MainEventBusProcessor runs asynchronously, so we use dequeueEventItem with timeout + mainQueue.enqueueEvent(event); - // Event should be available for dequeuing - Event dequeuedEvent = queue.dequeueEventItem(-1).getEvent(); + // New child queue should receive the event (old closed child was removed from children list) + EventQueueItem item = newChildQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); - // Now queue is closed and empty, should throw exception - assertThrows(EventQueueClosedException.class, () -> queue.dequeueEventItem(-1)); + // Now new child queue is closed and empty, should throw exception + newChildQueue.close(); + assertThrows(EventQueueClosedException.class, () -> newChildQueue.dequeueEventItem(-1)); } @Test public void testDequeueEventWhenClosedAndEmpty() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build().tap(); queue.close(); assertTrue(queue.isClosed()); @@ -217,19 +299,27 @@ public void testDequeueEventWhenClosedAndEmpty() throws Exception { @Test public void testDequeueEventWhenClosedButHasEvents() throws Exception { - EventQueue queue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus) + .taskId(TASK_ID) + .build(); + EventQueue childQueue = mainQueue.tap(); Event event = fromJson(MINIMAL_TASK, Task.class); - queue.enqueueEvent(event); - queue.close(); // Graceful close - events should remain - assertTrue(queue.isClosed()); + // Use 
callback to wait for event processing instead of polling + waitForEventProcessing(() -> mainQueue.enqueueEvent(event)); - // Should still be able to dequeue existing events - Event dequeuedEvent = queue.dequeueEventItem(-1).getEvent(); + // At this point, event has been processed and distributed to childQueue + childQueue.close(); // Graceful close - events should remain + assertTrue(childQueue.isClosed()); + + // Should still be able to dequeue existing events from closed queue + EventQueueItem item = childQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); // Now queue is closed and empty, should throw exception - assertThrows(EventQueueClosedException.class, () -> queue.dequeueEventItem(-1)); + assertThrows(EventQueueClosedException.class, () -> childQueue.dequeueEventItem(-1)); } @Test @@ -244,7 +334,9 @@ public void testEnqueueAndDequeueEvent() throws Exception { public void testDequeueEventNoWait() throws Exception { Event event = fromJson(MINIMAL_TASK, Task.class); eventQueue.enqueueEvent(event); - Event dequeuedEvent = eventQueue.dequeueEventItem(-1).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event dequeuedEvent = item.getEvent(); assertSame(event, dequeuedEvent); } @@ -257,7 +349,7 @@ public void testDequeueEventEmptyQueueNoWait() throws Exception { @Test public void testDequeueEventWait() throws Exception { Event event = TaskStatusUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .status(new TaskStatus(TaskState.WORKING)) .isFinal(true) @@ -271,7 +363,7 @@ public void testDequeueEventWait() throws Exception { @Test public void testTaskDone() throws Exception { Event event = TaskArtifactUpdateEvent.builder() - .taskId("task-123") + .taskId(TASK_ID) .contextId("session-xyz") .artifact(Artifact.builder() .artifactId("11") @@ -347,7 +439,7 @@ public void testCloseIdempotent() throws Exception { 
assertTrue(eventQueue.isClosed()); // Test with immediate close as well - EventQueue eventQueue2 = EventQueue.builder().build(); + EventQueue eventQueue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); eventQueue2.close(true); assertTrue(eventQueue2.isClosed()); @@ -361,19 +453,20 @@ public void testCloseIdempotent() throws Exception { */ @Test public void testCloseChildQueues() throws Exception { - EventQueue childQueue = eventQueue.tap(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue childQueue = mainQueue.tap(); assertTrue(childQueue != null); // Graceful close - parent closes but children remain open - eventQueue.close(); - assertTrue(eventQueue.isClosed()); + mainQueue.close(); + assertTrue(mainQueue.isClosed()); assertFalse(childQueue.isClosed()); // Child NOT closed on graceful parent close // Immediate close - parent force-closes all children - EventQueue parentQueue2 = EventQueue.builder().build(); - EventQueue childQueue2 = parentQueue2.tap(); - parentQueue2.close(true); // immediate=true - assertTrue(parentQueue2.isClosed()); + EventQueue mainQueue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); + EventQueue childQueue2 = mainQueue2.tap(); + mainQueue2.close(true); // immediate=true + assertTrue(mainQueue2.isClosed()); assertTrue(childQueue2.isClosed()); // Child IS closed on immediate parent close } @@ -383,7 +476,7 @@ public void testCloseChildQueues() throws Exception { */ @Test public void testMainQueueReferenceCountingStaysOpenWithActiveChildren() throws Exception { - EventQueue mainQueue = EventQueue.builder().build(); + EventQueue mainQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build(); EventQueue child1 = mainQueue.tap(); EventQueue child2 = mainQueue.tap(); diff --git a/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java b/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java index 39201c1f6..6c9ed4a17 100644 --- 
a/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java +++ b/server-common/src/test/java/io/a2a/server/events/EventQueueUtil.java @@ -1,8 +1,39 @@ package io.a2a.server.events; +import java.util.concurrent.atomic.AtomicInteger; + public class EventQueueUtil { - // Since EventQueue.builder() is package protected, add a method to expose it - public static EventQueue.EventQueueBuilder getEventQueueBuilder() { - return EventQueue.builder(); + // Counter for generating unique test taskIds + private static final AtomicInteger TASK_ID_COUNTER = new AtomicInteger(0); + + /** + * Get an EventQueue builder pre-configured with the shared test MainEventBus and a unique taskId. + *
+     * Note: Returns MainQueue - tests should call .tap() if they need to consume events.
+     *
+     * @return builder with the given MainEventBus and a unique taskId already set
+     */
+    public static EventQueue.EventQueueBuilder getEventQueueBuilder(MainEventBus eventBus) {
+        return EventQueue.builder(eventBus)
+                .taskId("test-task-" + TASK_ID_COUNTER.incrementAndGet());
+    }
+
+    /**
+     * Start a MainEventBusProcessor instance.
+     *
+     * @param processor the processor to start
+     */
+    public static void start(MainEventBusProcessor processor) {
+        processor.start();
+    }
+
+    /**
+     * Stop a MainEventBusProcessor instance.
+     *
+     * @param processor the processor to stop
+     */
+    public static void stop(MainEventBusProcessor processor) {
+        processor.stop();
+    }
 }
diff --git a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java
index 1eca1b739..3e09ff2af 100644
--- a/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java
+++ b/server-common/src/test/java/io/a2a/server/events/InMemoryQueueManagerTest.java
@@ -14,7 +14,10 @@
 import java.util.concurrent.ExecutionException;
 import java.util.stream.IntStream;
+import io.a2a.server.tasks.InMemoryTaskStore;
 import io.a2a.server.tasks.MockTaskStateProvider;
+import io.a2a.server.tasks.PushNotificationSender;
+import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -22,17 +25,30 @@ public class InMemoryQueueManagerTest {
     private InMemoryQueueManager queueManager;
     private MockTaskStateProvider taskStateProvider;
+    private InMemoryTaskStore taskStore;
+    private MainEventBus mainEventBus;
+    private MainEventBusProcessor mainEventBusProcessor;
+    private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {};
     @BeforeEach
     public void setUp() {
         taskStateProvider = new MockTaskStateProvider();
-        queueManager = new InMemoryQueueManager(taskStateProvider);
+        taskStore = new InMemoryTaskStore();
+        mainEventBus = new MainEventBus();
+
queueManager = new InMemoryQueueManager(taskStateProvider, mainEventBus);
+        mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager);
+        EventQueueUtil.start(mainEventBusProcessor);
+    }
+
+    @AfterEach
+    public void tearDown() {
+        EventQueueUtil.stop(mainEventBusProcessor);
+    }
     @Test
     public void testAddNewQueue() {
         String taskId = "test_task_id";
-        EventQueue queue = EventQueue.builder().build();
+        EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build();
         queueManager.add(taskId, queue);
@@ -43,8 +59,8 @@ public void testAddNewQueue() {
     @Test
     public void testAddExistingQueueThrowsException() {
         String taskId = "test_task_id";
-        EventQueue queue1 = EventQueue.builder().build();
-        EventQueue queue2 = EventQueue.builder().build();
+        EventQueue queue1 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build();
+        EventQueue queue2 = EventQueueUtil.getEventQueueBuilder(mainEventBus).build();
         queueManager.add(taskId, queue1);
@@ -56,7 +72,7 @@ public void testAddExistingQueueThrowsException() {
     @Test
     public void testGetExistingQueue() {
         String taskId = "test_task_id";
-        EventQueue queue = EventQueue.builder().build();
+        EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build();
         queueManager.add(taskId, queue);
         EventQueue result = queueManager.get(taskId);
@@ -73,7 +89,7 @@ public void testGetNonexistentQueue() {
     @Test
     public void testTapExistingQueue() {
         String taskId = "test_task_id";
-        EventQueue queue = EventQueue.builder().build();
+        EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build();
         queueManager.add(taskId, queue);
         EventQueue tappedQueue = queueManager.tap(taskId);
@@ -94,7 +110,7 @@ public void testTapNonexistentQueue() {
     @Test
     public void testCloseExistingQueue() {
         String taskId = "test_task_id";
-        EventQueue queue = EventQueue.builder().build();
+        EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build();
        queueManager.add(taskId, queue);
        queueManager.close(taskId);
@@ -129,7 +145,7 @@ public void testCreateOrTapNewQueue() {
     @Test
     public void testCreateOrTapExistingQueue() {
         String taskId = "test_task_id";
-        EventQueue originalQueue = EventQueue.builder().build();
+        EventQueue originalQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build();
         queueManager.add(taskId, originalQueue);
         EventQueue result = queueManager.createOrTap(taskId);
@@ -151,7 +167,7 @@ public void testConcurrentOperations() throws InterruptedException, ExecutionException {
         // Add tasks concurrently
         List<CompletableFuture<String>> addFutures = taskIds.stream()
                 .map(taskId -> CompletableFuture.supplyAsync(() -> {
-                    EventQueue queue = EventQueue.builder().build();
+                    EventQueue queue = EventQueueUtil.getEventQueueBuilder(mainEventBus).build();
                     queueManager.add(taskId, queue);
                     return taskId;
                 }))
diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java
index ea5bbe797..9c64f03f9 100644
--- a/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java
+++ b/server-common/src/test/java/io/a2a/server/requesthandlers/AbstractA2ARequestHandlerTest.java
@@ -26,7 +26,10 @@
 import io.a2a.server.agentexecution.RequestContext;
 import io.a2a.server.events.EventQueue;
 import io.a2a.server.events.EventQueueItem;
+import io.a2a.server.events.EventQueueUtil;
 import io.a2a.server.events.InMemoryQueueManager;
+import io.a2a.server.events.MainEventBus;
+import io.a2a.server.events.MainEventBusProcessor;
 import io.a2a.server.tasks.BasePushNotificationSender;
 import io.a2a.server.tasks.InMemoryPushNotificationConfigStore;
 import io.a2a.server.tasks.InMemoryTaskStore;
@@ -66,6 +69,8 @@ public class AbstractA2ARequestHandlerTest {
     private static final String PREFERRED_TRANSPORT = "preferred-transport";
     private static final String A2A_REQUESTHANDLER_TEST_PROPERTIES =
"/a2a-requesthandler-test.properties"; + private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {}; + protected AgentExecutor executor; protected TaskStore taskStore; protected RequestHandler requestHandler; @@ -73,6 +78,8 @@ public class AbstractA2ARequestHandlerTest { protected AgentExecutorMethod agentExecutorCancel; protected InMemoryQueueManager queueManager; protected TestHttpClient httpClient; + protected MainEventBus mainEventBus; + protected MainEventBusProcessor mainEventBusProcessor; protected final Executor internalExecutor = Executors.newCachedThreadPool(); @@ -96,19 +103,31 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws A2AErro InMemoryTaskStore inMemoryTaskStore = new InMemoryTaskStore(); taskStore = inMemoryTaskStore; - queueManager = new InMemoryQueueManager(inMemoryTaskStore); + + // Create push notification components BEFORE MainEventBusProcessor httpClient = new TestHttpClient(); PushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); PushNotificationSender pushSender = new BasePushNotificationSender(pushConfigStore, httpClient); + // Create MainEventBus and MainEventBusProcessor (production code path) + mainEventBus = new MainEventBus(); + queueManager = new InMemoryQueueManager(inMemoryTaskStore, mainEventBus); + mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, pushSender, queueManager); + EventQueueUtil.start(mainEventBusProcessor); + requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, pushConfigStore, pushSender, internalExecutor); + executor, taskStore, queueManager, pushConfigStore, mainEventBusProcessor, internalExecutor, internalExecutor); } @AfterEach public void cleanup() { agentExecutorExecute = null; agentExecutorCancel = null; + + // Stop MainEventBusProcessor background thread + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } } protected static 
AgentCard createAgentCard(boolean streaming, boolean pushNotifications, boolean stateTransitionHistory) { diff --git a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java index 293babe4e..e69de29bb 100644 --- a/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java +++ b/server-common/src/test/java/io/a2a/server/requesthandlers/DefaultRequestHandlerTest.java @@ -1,1001 +0,0 @@ -package io.a2a.server.requesthandlers; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; - -import io.a2a.server.ServerCallContext; -import io.a2a.server.agentexecution.AgentExecutor; -import io.a2a.server.agentexecution.RequestContext; -import io.a2a.server.auth.UnauthenticatedUser; -import io.a2a.server.events.EventQueue; -import io.a2a.server.events.InMemoryQueueManager; -import io.a2a.server.tasks.InMemoryPushNotificationConfigStore; -import io.a2a.server.tasks.InMemoryTaskStore; -import io.a2a.server.tasks.TaskUpdater; -import io.a2a.spec.A2AError; -import io.a2a.spec.ListTaskPushNotificationConfigParams; -import io.a2a.spec.ListTaskPushNotificationConfigResult; -import io.a2a.spec.Message; -import io.a2a.spec.MessageSendConfiguration; -import io.a2a.spec.MessageSendParams; -import io.a2a.spec.PushNotificationConfig; -import io.a2a.spec.Task; -import io.a2a.spec.TaskState; -import io.a2a.spec.TaskStatus; -import io.a2a.spec.TextPart; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; - 
-/** - * Comprehensive tests for DefaultRequestHandler, backported from Python's - * test_default_request_handler.py. These tests cover core functionality that - * is transport-agnostic and should work across JSON-RPC, gRPC, and REST. - * - * Background cleanup and task tracking tests are from Python PR #440 and #472. - */ -public class DefaultRequestHandlerTest { - - private DefaultRequestHandler requestHandler; - private InMemoryTaskStore taskStore; - private InMemoryQueueManager queueManager; - private TestAgentExecutor agentExecutor; - private ServerCallContext serverCallContext; - - @BeforeEach - void setUp() { - taskStore = new InMemoryTaskStore(); - // Pass taskStore as TaskStateProvider to queueManager for task-aware queue management - queueManager = new InMemoryQueueManager(taskStore); - agentExecutor = new TestAgentExecutor(); - - requestHandler = DefaultRequestHandler.create( - agentExecutor, - taskStore, - queueManager, - null, // pushConfigStore - null, // pushSender - Executors.newCachedThreadPool() - ); - - serverCallContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of(), Set.of()); - } - - /** - * Test that multiple blocking messages to the same task work correctly - * when agent doesn't emit final events (fire-and-forget pattern). 
- * This replicates TCK test: test_message_send_continue_task - */ - @Test - @Timeout(10) - void testBlockingMessageContinueTask() throws Exception { - String taskId = "continue-task-1"; - String contextId = "continue-ctx-1"; - - // Configure agent to NOT complete tasks (like TCK fire-and-forget agent) - agentExecutor.setExecuteCallback((context, queue) -> { - Task task = context.getTask(); - if (task == null) { - // First message: create SUBMITTED task - task = Task.builder() - .id(context.getTaskId()) - .contextId(context.getContextId()) - .status(new TaskStatus(TaskState.SUBMITTED)) - .build(); - } else { - // Subsequent messages: emit WORKING task (non-final) - task = Task.builder() - .id(context.getTaskId()) - .contextId(context.getContextId()) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - } - queue.enqueueEvent(task); - // Don't complete - just return (fire-and-forget) - }); - - // First blocking message - should return SUBMITTED task - Message message1 = Message.builder() - .messageId("msg-1") - .role(Message.Role.USER) - .parts(new TextPart("first message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params1 = new MessageSendParams(message1, null, null, ""); - Object result1 = requestHandler.onMessageSend(params1, serverCallContext); - - assertTrue(result1 instanceof Task); - Task task1 = (Task) result1; - assertTrue(task1.id().equals(taskId)); - assertTrue(task1.status().state() == TaskState.SUBMITTED); - - // Second blocking message to SAME taskId - should not hang - Message message2 = Message.builder() - .messageId("msg-2") - .role(Message.Role.USER) - .parts(new TextPart("second message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params2 = new MessageSendParams(message2, null, null, ""); - Object result2 = requestHandler.onMessageSend(params2, serverCallContext); - - // Should complete successfully (not timeout) - assertTrue(result2 instanceof Task); - } - - /** - * Test 
that background cleanup tasks are properly tracked and cleared. - * Backported from Python test: test_background_cleanup_task_is_tracked_and_cleared - */ - @Test - @Timeout(10) - void testBackgroundCleanupTaskIsTrackedAndCleared() throws Exception { - String taskId = "track-task-1"; - String contextId = "track-ctx-1"; - - // Create a task that will trigger background cleanup - Task task = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.SUBMITTED)) - .build(); - - taskStore.save(task); - - Message message = Message.builder() - .messageId("msg-track") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Set up agent to finish quickly so cleanup runs - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowAgentFinish = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - try { - allowAgentFinish.await(5, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }); - - // Start streaming (this will create background tasks) - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start"); - - // Allow agent to finish, which should trigger cleanup - allowAgentFinish.countDown(); - - // Give some time for background tasks to be tracked and cleaned up - Thread.sleep(1000); - - // Background tasks should eventually be cleared - // Note: We can't directly access the backgroundTasks field without reflection, - // but the test verifies the mechanism doesn't hang or leak tasks - assertTrue(true, "Background cleanup completed without hanging"); - } - - /** - * Test that client disconnect triggers background cleanup and producer continues. 
- * Backported from Python test: test_on_message_send_stream_client_disconnect_triggers_background_cleanup_and_producer_continues - */ - @Test - @Timeout(10) - void testStreamingClientDisconnectTriggersBackgroundCleanup() throws Exception { - String taskId = "disc-task-1"; - String contextId = "disc-ctx-1"; - - Message message = Message.builder() - .messageId("mid") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Agent should start and then wait - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowAgentFinish = new CountDownLatch(1); - AtomicBoolean agentCompleted = new AtomicBoolean(false); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - try { - allowAgentFinish.await(10, TimeUnit.SECONDS); - agentCompleted.set(true); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }); - - // Start streaming - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start executing"); - - // Simulate client disconnect by not consuming the stream - // In real scenarios, the client would close the connection - - // Agent should still be running (not finished immediately on "disconnect") - Thread.sleep(500); - assertTrue(agentExecutor.isExecuting(), "Producer should still be running after simulated disconnect"); - - // Allow agent to finish - allowAgentFinish.countDown(); - - // Wait a bit for completion - Thread.sleep(1000); - - assertTrue(agentCompleted.get(), "Agent should have completed execution"); - } - - /** - * Test that resubscription works after client disconnect. 
- * Backported from Python test: test_stream_disconnect_then_resubscribe_receives_future_events - */ - @Test - @Timeout(15) - void testStreamDisconnectThenResubscribeReceivesFutureEvents() throws Exception { - String taskId = "reconn-task-1"; - String contextId = "reconn-ctx-1"; - - // Create initial task - Task initialTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - - taskStore.save(initialTask); - - Message message = Message.builder() - .messageId("msg-reconn") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Set up agent to emit events with controlled timing - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowSecondEvent = new CountDownLatch(1); - CountDownLatch allowFinish = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - - // Emit first event - Task firstEvent = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - queue.enqueueEvent(firstEvent); - - // Wait for permission to emit second event - try { - allowSecondEvent.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; - } - - // Emit second event - Task secondEvent = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - queue.enqueueEvent(secondEvent); - - try { - allowFinish.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }); - - // Start streaming and simulate getting first event then disconnecting - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start and emit first event - 
assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start"); - - // Simulate client disconnect (in real scenario, client would close connection) - // The background cleanup should keep the producer running - - // Now try to resubscribe to the task - io.a2a.spec.TaskIdParams resubParams = new io.a2a.spec.TaskIdParams(taskId); - - // Allow agent to emit second event - allowSecondEvent.countDown(); - - // Try resubscription - this should work because queue is still alive - var resubResult = requestHandler.onResubscribeToTask(resubParams, serverCallContext); - // If we get here without exception, resubscription worked - assertTrue(true, "Resubscription succeeded"); - - // Clean up - allowFinish.countDown(); - } - - /** - * Test that task state is persisted to task store after client disconnect. - * Backported from Python test: test_disconnect_persists_final_task_to_store - */ - @Test - @Timeout(15) - void testDisconnectPersistsFinalTaskToStore() throws Exception { - String taskId = "persist-task-1"; - String contextId = "persist-ctx-1"; - - Message message = Message.builder() - .messageId("msg-persist") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Agent that completes after some delay - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowCompletion = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - - // Emit working status - Task workingTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - queue.enqueueEvent(workingTask); - - try { - allowCompletion.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; - } - - // Emit final completed status - Task completedTask = Task.builder() - 
.id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - queue.enqueueEvent(completedTask); - }); - - // Start streaming and simulate client disconnect - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start"); - - // Simulate client disconnect by not consuming the stream further - // In real scenarios, the reactive stream would be cancelled - - // Allow agent to complete in background - allowCompletion.countDown(); - - // Give time for background processing to persist the final state - Thread.sleep(2000); - - // Verify the final task state was persisted despite client disconnect - Task persistedTask = taskStore.get(taskId); - if (persistedTask != null) { - // If task was persisted, it should have the final state - assertTrue( - persistedTask.status().state() == TaskState.COMPLETED || - persistedTask.status().state() == TaskState.WORKING, - "Task should be persisted with working or completed state, got: " + persistedTask.status().state() - ); - } - // Note: In some architectures, the task might not be persisted if the - // background consumption isn't implemented. This test documents the expected behavior. - } - - /** - * Test that blocking message call waits for agent to finish and returns complete Task - * even when agent does fire-and-forget (emits non-final state and returns). - * - * Expected behavior: - * 1. Agent emits WORKING state with artifacts - * 2. Agent's execute() method returns WITHOUT emitting final state - * 3. Blocking onMessageSend() should wait for agent execution to complete - * 4. Blocking onMessageSend() should wait for all queued events to be processed - * 5. Returned Task should have WORKING state with all artifacts included - * - * This tests fire-and-forget pattern with blocking calls. 
- */ - @Test - @Timeout(15) - void testBlockingFireAndForgetReturnsNonFinalTask() throws Exception { - String taskId = "blocking-fire-forget-task"; - String contextId = "blocking-fire-forget-ctx"; - - Message message = Message.builder() - .messageId("msg-blocking-fire-forget") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendConfiguration config = MessageSendConfiguration.builder() - .blocking(true) - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent that does fire-and-forget: emits WORKING with artifact but never completes - agentExecutor.setExecuteCallback((context, queue) -> { - TaskUpdater updater = new TaskUpdater(context, queue); - - // Start work (WORKING state) - updater.startWork(); - - // Add artifact - updater.addArtifact( - List.of(new TextPart("Fire and forget artifact")), - "artifact-1", "FireForget", null); - - // Agent returns WITHOUT calling updater.complete() - // Task stays in WORKING state (non-final) - }); - - // Call blocking onMessageSend - should wait for agent to finish - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // The returned result should be a Task in WORKING state with artifact - assertTrue(result instanceof Task, "Result should be a Task"); - Task returnedTask = (Task) result; - - // Verify task is in WORKING state (non-final, fire-and-forget) - assertEquals(TaskState.WORKING, returnedTask.status().state(), - "Returned task should be WORKING (fire-and-forget), got: " + returnedTask.status().state()); - - // Verify artifacts are included in the returned task - assertNotNull(returnedTask.artifacts(), - "Returned task should have artifacts"); - assertTrue(returnedTask.artifacts().size() >= 1, - "Returned task should have at least 1 artifact, got: " + - returnedTask.artifacts().size()); - } - - /** - * Test that non-blocking message call returns immediately and 
persists all events in background. - * - * Expected behavior: - * 1. Non-blocking call returns immediately with first event (WORKING state) - * 2. Agent continues running in background and produces more events - * 3. Background consumption continues and persists all events to TaskStore - * 4. Final task state (COMPLETED) is persisted in background - */ - @Test - @Timeout(15) - void testNonBlockingMessagePersistsAllEventsInBackground() throws Exception { - String taskId = "blocking-persist-task"; - String contextId = "blocking-persist-ctx"; - - Message message = Message.builder() - .messageId("msg-nonblocking-persist") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - // Use default non-blocking behavior - MessageSendConfiguration config = MessageSendConfiguration.builder() - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent that produces multiple events with delays - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch firstEventEmitted = new CountDownLatch(1); - CountDownLatch allowCompletion = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - - // Emit first event (WORKING state) - Task workingTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - queue.enqueueEvent(workingTask); - firstEventEmitted.countDown(); - - // Sleep to ensure the non-blocking call has returned before we emit more events - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; - } - - // Wait for permission to complete - try { - allowCompletion.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return; - } - - // Emit final event (COMPLETED state) - // This event should be persisted to TaskStore in background - 
Task completedTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - queue.enqueueEvent(completedTask); - }); - - // Call non-blocking onMessageSend - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // Assertion 1: The immediate result should be the first event (WORKING) - assertTrue(result instanceof Task, "Result should be a Task"); - Task immediateTask = (Task) result; - assertEquals(TaskState.WORKING, immediateTask.status().state(), - "Non-blocking should return immediately with WORKING state, got: " + immediateTask.status().state()); - - // At this point, the non-blocking call has returned, but the agent is still running - - // Allow the agent to emit the final COMPLETED event - allowCompletion.countDown(); - - // Assertion 2: Poll for the final task state to be persisted in background - // Use polling loop instead of fixed sleep for faster and more reliable test - long timeoutMs = 5000; - long startTime = System.currentTimeMillis(); - Task persistedTask = null; - boolean completedStateFound = false; - - while (System.currentTimeMillis() - startTime < timeoutMs) { - persistedTask = taskStore.get(taskId); - if (persistedTask != null && persistedTask.status().state() == TaskState.COMPLETED) { - completedStateFound = true; - break; - } - Thread.sleep(100); // Poll every 100ms - } - - assertTrue(persistedTask != null, "Task should be persisted to store"); - assertTrue( - completedStateFound, - "Final task state should be COMPLETED (background consumption should have processed it), got: " + - (persistedTask != null ? persistedTask.status().state() : "null") + - " after " + (System.currentTimeMillis() - startTime) + "ms" - ); - } - - /** - * Test the BIG idea: MainQueue stays open for non-final tasks even when all children close. - * This enables fire-and-forget tasks and late resubscriptions. 
- */ - @Test - @Timeout(15) - void testMainQueueStaysOpenForNonFinalTasks() throws Exception { - String taskId = "fire-and-forget-task"; - String contextId = "fire-ctx"; - - // Create initial task in WORKING state (non-final) - Task initialTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - taskStore.save(initialTask); - - Message message = Message.builder() - .messageId("msg-fire") - .role(Message.Role.USER) - .parts(new TextPart("fire and forget")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendParams params = new MessageSendParams(message, null, null, ""); - - // Agent that emits WORKING status but never completes (fire-and-forget pattern) - CountDownLatch agentStarted = new CountDownLatch(1); - CountDownLatch allowAgentFinish = new CountDownLatch(1); - - agentExecutor.setExecuteCallback((context, queue) -> { - agentStarted.countDown(); - - // Emit WORKING status (non-final) - Task workingTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - queue.enqueueEvent(workingTask); - - // Don't emit final state - just wait and finish - try { - allowAgentFinish.await(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - // Agent finishes WITHOUT emitting final task state - }); - - // Start streaming - var streamingResult = requestHandler.onMessageSendStream(params, serverCallContext); - - // Wait for agent to start and emit WORKING event - assertTrue(agentStarted.await(5, TimeUnit.SECONDS), "Agent should start"); - - // Give time for WORKING event to be processed - Thread.sleep(500); - - // Simulate client disconnect - this closes the ChildQueue - // but MainQueue should stay open because task is non-final - - // Allow agent to finish - allowAgentFinish.countDown(); - - // Give time for agent to finish and cleanup to run - Thread.sleep(2000); - - // THE BIG IDEA TEST: 
Resubscription should work because MainQueue is still open - // Even though: - // 1. The original ChildQueue closed (client disconnected) - // 2. The agent finished executing - // 3. Task is still in non-final WORKING state - // Therefore: MainQueue should still be open for resubscriptions - - io.a2a.spec.TaskIdParams resubParams = new io.a2a.spec.TaskIdParams(taskId); - var resubResult = requestHandler.onResubscribeToTask(resubParams, serverCallContext); - - // A non-null result (rather than an exception) proves the MainQueue stayed open - assertNotNull(resubResult, "Resubscription should succeed because the MainQueue stayed open for the non-final task"); - } - - /** - * Test that MainQueue DOES close when task is finalized. - * This ensures Level 2 protection doesn't prevent cleanup of completed tasks. - */ - @Test - @Timeout(15) - void testMainQueueClosesForFinalizedTasks() throws Exception { - String taskId = "completed-task"; - String contextId = "completed-ctx"; - - // Create initial task in COMPLETED state (already finalized) - Task completedTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - taskStore.save(completedTask); - - // Create a queue for this task - EventQueue mainQueue = queueManager.createOrTap(taskId); - assertNotNull(mainQueue, "Queue should be created"); - - // Close the child queue (simulating client disconnect) - mainQueue.close(); - - // Give time for cleanup callback to run - Thread.sleep(1000); - - // Since the task is finalized (COMPLETED), the MainQueue should be removed from the map - // This tests that Level 2 protection (childClosing check) allows cleanup for finalized tasks - EventQueue queue = queueManager.get(taskId); - assertTrue(queue == null || queue.isClosed(), - "Queue for finalized task should be null or closed"); - } - - /** - * Test that blocking message call returns a Task with ALL artifacts included. - * This reproduces the reported bug: blocking call returns before artifacts are processed.
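The two tests above exercise the two-level lifecycle protection: the cleanup callback (Level 1) and the auto-close on last child disconnect (Level 2) both consult task finality first. A simplified model of that logic, with all types as illustrative stand-ins for the SDK's `MainQueue`, `QueueManager`, and `TaskStateProvider`:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class QueueLifecycleSketch {

    interface TaskStateProvider {
        boolean isTaskFinalized(String taskId);
    }

    static class MainQueue {
        private final String taskId;
        private final TaskStateProvider stateProvider;
        private final AtomicInteger childCount = new AtomicInteger();
        private volatile boolean closed = false;

        MainQueue(String taskId, TaskStateProvider stateProvider) {
            this.taskId = taskId;
            this.stateProvider = stateProvider;
        }

        void childOpened() {
            childCount.incrementAndGet();
        }

        /** Level 2: auto-close only when the last child leaves AND the task is final. */
        void childClosing() {
            if (childCount.decrementAndGet() == 0 && stateProvider.isTaskFinalized(taskId)) {
                closed = true;
            }
        }

        boolean isClosed() {
            return closed;
        }
    }

    static class QueueManager {
        private final Map<String, MainQueue> queues = new ConcurrentHashMap<>();
        private final TaskStateProvider stateProvider;

        QueueManager(TaskStateProvider stateProvider) {
            this.stateProvider = stateProvider;
        }

        MainQueue createOrTap(String taskId) {
            return queues.computeIfAbsent(taskId, id -> new MainQueue(id, stateProvider));
        }

        /** Level 1: the cleanup callback removes the queue only for finalized tasks. */
        void cleanup(String taskId) {
            if (stateProvider.isTaskFinalized(taskId)) {
                queues.remove(taskId);
            }
        }

        boolean contains(String taskId) {
            return queues.containsKey(taskId);
        }
    }

    public static void main(String[] args) {
        QueueManager qm = new QueueManager(taskId -> taskId.startsWith("done"));
        MainQueue working = qm.createOrTap("working-1");
        working.childOpened();
        working.childClosing();   // last child leaves, but the task is non-final
        qm.cleanup("working-1");
        System.out.println("non-final queue still registered: " + qm.contains("working-1"));
    }
}
```

The net effect matches the tests: a WORKING task's queue survives both the disconnect and the cleanup pass (so resubscription can succeed), while a COMPLETED task's queue closes and is removed.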
- * - * Expected behavior: - * 1. Agent emits multiple artifacts via TaskUpdater - * 2. Blocking onMessageSend() should wait for ALL events to be processed - * 3. Returned Task should have all artifacts included in COMPLETED state - * - * Bug manifestation: - * - onMessageSend() returns after first event - * - Artifacts are still being processed in background - * - Returned Task is incomplete - */ - @Test - @Timeout(15) - void testBlockingCallReturnsCompleteTaskWithArtifacts() throws Exception { - String taskId = "blocking-artifacts-task"; - String contextId = "blocking-artifacts-ctx"; - - Message message = Message.builder() - .messageId("msg-blocking-artifacts") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendConfiguration config = MessageSendConfiguration.builder() - .blocking(true) - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent that uses TaskUpdater to emit multiple artifacts (like real agents do) - agentExecutor.setExecuteCallback((context, queue) -> { - TaskUpdater updater = new TaskUpdater(context, queue); - - // Start work (WORKING state) - updater.startWork(); - - // Add first artifact - updater.addArtifact( - List.of(new TextPart("First artifact")), - "artifact-1", "First", null); - - // Add second artifact - updater.addArtifact( - List.of(new TextPart("Second artifact")), - "artifact-2", "Second", null); - - // Complete the task - updater.complete(); - }); - - // Call blocking onMessageSend - should wait for ALL events - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // The returned result should be a Task with ALL artifacts - assertTrue(result instanceof Task, "Result should be a Task"); - Task returnedTask = (Task) result; - - // Verify task is completed - assertEquals(TaskState.COMPLETED, returnedTask.status().state(), - "Returned task should be COMPLETED"); - - // Verify 
artifacts are included in the returned task - assertNotNull(returnedTask.artifacts(), - "Returned task should have artifacts"); - assertTrue(returnedTask.artifacts().size() >= 2, - "Returned task should have at least 2 artifacts, got: " + - returnedTask.artifacts().size()); - } - - /** - * Test that pushNotificationConfig from SendMessageConfiguration is stored for NEW tasks - * in non-streaming (blocking) mode. This reproduces the bug from issue #84. - * - * Expected behavior: - * 1. Client sends message with pushNotificationConfig in SendMessageConfiguration - * 2. Agent creates a new task - * 3. pushNotificationConfig should be stored in PushNotificationConfigStore - * 4. Config should be retrievable via getInfo() - */ - @Test - @Timeout(10) - void testBlockingMessageStoresPushNotificationConfigForNewTask() throws Exception { - String taskId = "push-config-blocking-new-task"; - String contextId = "push-config-ctx"; - - // Create test config store - InMemoryPushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); - - // Re-create request handler with pushConfigStore - requestHandler = DefaultRequestHandler.create( - agentExecutor, - taskStore, - queueManager, - pushConfigStore, // Add push config store - null, // pushSender - Executors.newCachedThreadPool() - ); - - // Create push notification config - PushNotificationConfig pushConfig = PushNotificationConfig.builder() - .id("config-1") - .url("https://example.com/webhook") - .token("test-token-123") - .build(); - - // Create message with pushNotificationConfig - Message message = Message.builder() - .messageId("msg-push-config") - .role(Message.Role.USER) - .parts(new TextPart("test message")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendConfiguration config = MessageSendConfiguration.builder() - .blocking(true) - .pushNotificationConfig(pushConfig) - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent 
creates a new task - agentExecutor.setExecuteCallback((context, queue) -> { - TaskUpdater updater = new TaskUpdater(context, queue); - updater.submit(); // Creates new task in SUBMITTED state - updater.complete(); - }); - - // Call blocking onMessageSend - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // Verify result is a task - assertTrue(result instanceof Task, "Result should be a Task"); - Task returnedTask = (Task) result; - assertEquals(taskId, returnedTask.id()); - - // THE KEY ASSERTION: Verify pushNotificationConfig was stored - ListTaskPushNotificationConfigResult storedConfigs = pushConfigStore.getInfo(new ListTaskPushNotificationConfigParams(taskId)); - assertNotNull(storedConfigs, "Push notification config should be stored for new task"); - assertEquals(1, storedConfigs.size(), - "Should have exactly 1 push config stored"); - PushNotificationConfig storedConfig = storedConfigs.configs().get(0).pushNotificationConfig(); - assertEquals("config-1", storedConfig.id()); - assertEquals("https://example.com/webhook", storedConfig.url()); - } - - /** - * Test that pushNotificationConfig is stored for EXISTING tasks. - * This verifies the initMessageSend logic works correctly. 
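The storage behavior these push-notification tests verify amounts to a per-task multimap of configs. A minimal sketch under stated assumptions: the `PushConfig` record and `InMemoryStore` here are illustrative stand-ins, not the SDK's `PushNotificationConfigStore` API:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class PushConfigStoreSketch {

    record PushConfig(String id, String url, String token) {}

    static class InMemoryStore {
        private final Map<String, List<PushConfig>> byTask = new ConcurrentHashMap<>();

        /** Stores a config for the given task (called when the incoming message carries one). */
        void setInfo(String taskId, PushConfig config) {
            byTask.computeIfAbsent(taskId, k -> new ArrayList<>()).add(config);
        }

        /** Returns all configs registered for the task, or an empty list. */
        List<PushConfig> getInfo(String taskId) {
            return byTask.getOrDefault(taskId, List.of());
        }
    }

    public static void main(String[] args) {
        InMemoryStore store = new InMemoryStore();
        store.setInfo("task-1", new PushConfig("config-1", "https://example.com/webhook", "tok"));
        System.out.println("configs for task-1: " + store.getInfo("task-1").size());
    }
}
```

The tests assert exactly this round trip: a config supplied in `MessageSendConfiguration` must be retrievable for both new and pre-existing tasks.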
- */ - @Test - @Timeout(10) - void testBlockingMessageStoresPushNotificationConfigForExistingTask() throws Exception { - String taskId = "push-config-existing-task"; - String contextId = "push-config-existing-ctx"; - - // Create test config store - InMemoryPushNotificationConfigStore pushConfigStore = new InMemoryPushNotificationConfigStore(); - - // Re-create request handler with pushConfigStore - requestHandler = DefaultRequestHandler.create( - agentExecutor, - taskStore, - queueManager, - pushConfigStore, // Add push config store - null, // pushSender - Executors.newCachedThreadPool() - ); - - // Create EXISTING task in store - Task existingTask = Task.builder() - .id(taskId) - .contextId(contextId) - .status(new TaskStatus(TaskState.WORKING)) - .build(); - taskStore.save(existingTask); - - // Create push notification config - PushNotificationConfig pushConfig = PushNotificationConfig.builder() - .id("config-existing-1") - .url("https://example.com/existing-webhook") - .token("existing-token-789") - .build(); - - Message message = Message.builder() - .messageId("msg-push-existing") - .role(Message.Role.USER) - .parts(new TextPart("update existing task")) - .taskId(taskId) - .contextId(contextId) - .build(); - - MessageSendConfiguration config = MessageSendConfiguration.builder() - .blocking(true) - .pushNotificationConfig(pushConfig) - .build(); - - MessageSendParams params = new MessageSendParams(message, config, null, ""); - - // Agent updates the existing task - agentExecutor.setExecuteCallback((context, queue) -> { - TaskUpdater updater = new TaskUpdater(context, queue); - updater.addArtifact( - List.of(new TextPart("update artifact")), - "artifact-1", "Update", null); - updater.complete(); - }); - - // Call blocking onMessageSend - Object result = requestHandler.onMessageSend(params, serverCallContext); - - // Verify result - assertTrue(result instanceof Task, "Result should be a Task"); - - // Verify pushNotificationConfig was stored (initMessageSend path) 
- ListTaskPushNotificationConfigResult storedConfigs = pushConfigStore.getInfo(new ListTaskPushNotificationConfigParams(taskId)); - assertNotNull(storedConfigs,"Push notification config should be stored for existing task"); - assertEquals(1, storedConfigs.size(),"Should have exactly 1 push config stored"); - PushNotificationConfig storedConfig = storedConfigs.configs().get(0).pushNotificationConfig(); - assertEquals("config-existing-1", storedConfig.id()); - assertEquals("https://example.com/existing-webhook", storedConfig.url()); - } - - /** - * Simple test agent executor that allows controlling execution timing - */ - private static class TestAgentExecutor implements AgentExecutor { - private ExecuteCallback executeCallback; - private volatile boolean executing = false; - - interface ExecuteCallback { - void call(RequestContext context, EventQueue queue) throws A2AError; - } - - void setExecuteCallback(ExecuteCallback callback) { - this.executeCallback = callback; - } - - boolean isExecuting() { - return executing; - } - - @Override - public void execute(RequestContext context, EventQueue eventQueue) throws A2AError { - executing = true; - try { - if (executeCallback != null) { - // Custom callback is responsible for emitting events - executeCallback.call(context, eventQueue); - } else { - // No custom callback - emit default completion event - Task completedTask = Task.builder() - .id(context.getTaskId()) - .contextId(context.getContextId()) - .status(new TaskStatus(TaskState.COMPLETED)) - .build(); - eventQueue.enqueueEvent(completedTask); - } - - } finally { - executing = false; - } - } - - @Override - public void cancel(RequestContext context, EventQueue eventQueue) throws A2AError { - // Simple cancel implementation - executing = false; - } - } -} \ No newline at end of file diff --git a/server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java b/server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java index 
e814c1c15..e69de29bb 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/InMemoryTaskStoreTest.java @@ -1,49 +0,0 @@ -package io.a2a.server.tasks; - -import static io.a2a.jsonrpc.common.json.JsonUtil.fromJson; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertSame; - -import io.a2a.spec.Task; -import org.junit.jupiter.api.Test; - -public class InMemoryTaskStoreTest { - private static final String TASK_JSON = """ - { - "id": "task-abc", - "contextId" : "session-xyz", - "status": {"state": "submitted"} - }"""; - - @Test - public void testSaveAndGet() throws Exception { - InMemoryTaskStore store = new InMemoryTaskStore(); - Task task = fromJson(TASK_JSON, Task.class); - store.save(task); - Task retrieved = store.get(task.id()); - assertSame(task, retrieved); - } - - @Test - public void testGetNonExistent() throws Exception { - InMemoryTaskStore store = new InMemoryTaskStore(); - Task retrieved = store.get("nonexistent"); - assertNull(retrieved); - } - - @Test - public void testDelete() throws Exception { - InMemoryTaskStore store = new InMemoryTaskStore(); - Task task = fromJson(TASK_JSON, Task.class); - store.save(task); - store.delete(task.id()); - Task retrieved = store.get(task.id()); - assertNull(retrieved); - } - - @Test - public void testDeleteNonExistent() throws Exception { - InMemoryTaskStore store = new InMemoryTaskStore(); - store.delete("non-existent"); - } -} diff --git a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java index d64729077..0e25e9aad 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/ResultAggregatorTest.java @@ -11,18 +11,25 @@ import static org.mockito.Mockito.when; import java.util.Collections; 
+import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; +import io.a2a.server.events.EventQueueUtil; import io.a2a.server.events.InMemoryQueueManager; +import io.a2a.server.events.MainEventBus; +import io.a2a.server.events.MainEventBusProcessor; +import io.a2a.spec.Event; import io.a2a.spec.EventKind; import io.a2a.spec.Message; import io.a2a.spec.Task; import io.a2a.spec.TaskState; import io.a2a.spec.TaskStatus; import io.a2a.spec.TextPart; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mock; @@ -49,7 +56,7 @@ public class ResultAggregatorTest { @BeforeEach void setUp() { MockitoAnnotations.openMocks(this); - aggregator = new ResultAggregator(mockTaskManager, null, testExecutor); + aggregator = new ResultAggregator(mockTaskManager, null, testExecutor, testExecutor); } // Helper methods for creating sample data @@ -69,13 +76,45 @@ private Task createSampleTask(String taskId, TaskState statusState, String conte .build(); } + /** + * Helper to wait for MainEventBusProcessor to process an event. + * Replaces polling patterns with deterministic callback-based waiting. 
+ * + * @param processor the processor to set callback on + * @param action the action that triggers event processing + * @throws InterruptedException if waiting is interrupted + * @throws AssertionError if processing doesn't complete within timeout + */ + private void waitForEventProcessing(MainEventBusProcessor processor, Runnable action) throws InterruptedException { + CountDownLatch processingLatch = new CountDownLatch(1); + processor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { + @Override + public void onEventProcessed(String taskId, Event event) { + processingLatch.countDown(); + } + + @Override + public void onTaskFinalized(String taskId) { + // Not needed for basic event processing wait + } + }); + + try { + action.run(); + assertTrue(processingLatch.await(5, TimeUnit.SECONDS), + "MainEventBusProcessor should have processed the event within timeout"); + } finally { + processor.setCallback(null); + } + } + // Basic functionality tests @Test void testConstructorWithMessage() { Message initialMessage = createSampleMessage("initial", "msg1", Message.Role.USER); - ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, initialMessage, testExecutor); + ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, initialMessage, testExecutor, testExecutor); // Test that the message is properly stored by checking getCurrentResult assertEquals(initialMessage, aggregatorWithMessage.getCurrentResult()); @@ -86,7 +125,7 @@ void testConstructorWithMessage() { @Test void testGetCurrentResultWithMessageSet() { Message sampleMessage = createSampleMessage("hola", "msg1", Message.Role.USER); - ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, sampleMessage, testExecutor); + ResultAggregator aggregatorWithMessage = new ResultAggregator(mockTaskManager, sampleMessage, testExecutor, testExecutor); EventKind result = aggregatorWithMessage.getCurrentResult(); @@ -121,7 +160,7 @@ void 
testConstructorStoresTaskManagerCorrectly() { @Test void testConstructorWithNullMessage() { - ResultAggregator aggregatorWithNullMessage = new ResultAggregator(mockTaskManager, null, testExecutor); + ResultAggregator aggregatorWithNullMessage = new ResultAggregator(mockTaskManager, null, testExecutor, testExecutor); Task expectedTask = createSampleTask("null_msg_task", TaskState.WORKING, "ctx1"); when(mockTaskManager.getTask()).thenReturn(expectedTask); @@ -181,7 +220,7 @@ void testMultipleGetCurrentResultCalls() { void testGetCurrentResultWithMessageTakesPrecedence() { // Test that when both message and task are available, message takes precedence Message message = createSampleMessage("priority message", "pri1", Message.Role.USER); - ResultAggregator messageAggregator = new ResultAggregator(mockTaskManager, message, testExecutor); + ResultAggregator messageAggregator = new ResultAggregator(mockTaskManager, message, testExecutor, testExecutor); // Even if we set up the task manager to return something, message should take precedence Task task = createSampleTask("should_not_be_returned", TaskState.WORKING, "ctx1"); @@ -197,17 +236,24 @@ void testGetCurrentResultWithMessageTakesPrecedence() { @Test void testConsumeAndBreakNonBlocking() throws Exception { // Test that with blocking=false, the method returns after the first event - Task firstEvent = createSampleTask("non_blocking_task", TaskState.WORKING, "ctx1"); + String taskId = "test-task"; + Task firstEvent = createSampleTask(taskId, TaskState.WORKING, "ctx1"); // After processing firstEvent, the current result will be that task when(mockTaskManager.getTask()).thenReturn(firstEvent); // Create an event queue using QueueManager (which has access to builder) + MainEventBus mainEventBus = new MainEventBus(); + InMemoryTaskStore taskStore = new InMemoryTaskStore(); InMemoryQueueManager queueManager = - new InMemoryQueueManager(new MockTaskStateProvider()); + new InMemoryQueueManager(new MockTaskStateProvider(), 
mainEventBus); + MainEventBusProcessor processor = new MainEventBusProcessor(mainEventBus, taskStore, task -> {}, queueManager); + EventQueueUtil.start(processor); + + EventQueue queue = queueManager.getEventQueueBuilder(taskId).build().tap(); - EventQueue queue = queueManager.getEventQueueBuilder("test-task").build(); - queue.enqueueEvent(firstEvent); + // Use callback to wait for event processing (replaces polling) + waitForEventProcessing(processor, () -> queue.enqueueEvent(firstEvent)); // Create real EventConsumer with the queue EventConsumer eventConsumer = @@ -221,11 +267,16 @@ void testConsumeAndBreakNonBlocking() throws Exception { assertEquals(firstEvent, result.eventType()); assertTrue(result.interrupted()); - verify(mockTaskManager).process(firstEvent); - // getTask() is called at least once for the return value (line 255) - // May be called once more if debug logging executes in time (line 209) - // The async consumer may or may not execute before verification, so we accept 1-2 calls - verify(mockTaskManager, atLeast(1)).getTask(); - verify(mockTaskManager, atMost(2)).getTask(); + // NOTE: ResultAggregator no longer calls taskManager.process() + // That responsibility has moved to MainEventBusProcessor for centralized persistence + // + // NOTE: Since firstEvent is a Task, ResultAggregator captures it directly from the queue + // (capturedTask.get() at line 283 in ResultAggregator). Therefore, taskManager.getTask() + // is only called for debug logging in taskIdForLogging() (line 305), which may or may not + // execute depending on timing and log level. We expect 0-1 calls, not 1-2. 
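The `waitForEventProcessing` helper introduced in this file replaces polling with a latch counted down by the processor's callback. The pattern in isolation, with `Processor` and the callback type as simplified stand-ins for the SDK's `MainEventBusProcessor` and its callback interface:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

public class CallbackWaitSketch {

    static class Processor {
        private volatile Consumer<String> callback;

        void setCallback(Consumer<String> callback) {
            this.callback = callback;
        }

        /** Simulates asynchronous event processing completing on another thread. */
        void process(String event) {
            new Thread(() -> {
                Consumer<String> cb = callback;
                if (cb != null) {
                    cb.accept(event);
                }
            }).start();
        }
    }

    /** Registers a latch-counting callback, runs the triggering action, then awaits the latch. */
    static boolean waitForProcessing(Processor processor, Runnable action, long timeoutSec) {
        CountDownLatch latch = new CountDownLatch(1);
        processor.setCallback(event -> latch.countDown());
        try {
            action.run();
            return latch.await(timeoutSec, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        } finally {
            processor.setCallback(null); // always detach so later waits start clean
        }
    }

    public static void main(String[] args) {
        Processor processor = new Processor();
        boolean processed = waitForProcessing(processor, () -> processor.process("event-1"), 5);
        System.out.println("processed within timeout: " + processed);
    }
}
```

Registering the callback before running the action is what makes the wait deterministic: the completion signal cannot be missed, unlike a poll that races the background thread.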
+ verify(mockTaskManager, atMost(1)).getTask(); + + // Cleanup: stop the processor + EventQueueUtil.stop(processor); } } diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java index f14ebc0fe..91010e52e 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java +++ b/server-common/src/test/java/io/a2a/server/tasks/TaskManagerTest.java @@ -48,7 +48,7 @@ public void init() throws Exception { @Test public void testGetTaskExisting() { Task expectedTask = minimalTask; - taskStore.save(expectedTask); + taskStore.save(expectedTask, false); Task retrieved = taskManager.getTask(); assertSame(expectedTask, retrieved); } @@ -61,16 +61,16 @@ public void testGetTaskNonExistent() { @Test public void testSaveTaskEventNewTask() throws A2AServerException { - Task saved = taskManager.saveTaskEvent(minimalTask); + taskManager.saveTaskEvent(minimalTask, false); + Task saved = taskManager.getTask(); Task retrieved = taskManager.getTask(); assertSame(minimalTask, retrieved); - assertSame(retrieved, saved); } @Test public void testSaveTaskEventStatusUpdate() throws A2AServerException { Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); TaskStatus newStatus = new TaskStatus( TaskState.WORKING, @@ -88,11 +88,11 @@ public void testSaveTaskEventStatusUpdate() throws A2AServerException { new HashMap<>()); - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updated = taskManager.getTask(); assertNotSame(initialTask, updated); - assertSame(updated, saved); assertEquals(initialTask.id(), updated.id()); assertEquals(initialTask.contextId(), updated.contextId()); @@ -114,10 +114,10 @@ public void testSaveTaskEventArtifactUpdate() throws A2AServerException { .contextId(minimalTask.contextId()) .artifact(newArtifact) .build(); - Task saved = 
taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updatedTask = taskManager.getTask(); - assertSame(updatedTask, saved); assertNotSame(initialTask, updatedTask); assertEquals(initialTask.id(), updatedTask.id()); @@ -144,7 +144,8 @@ public void testEnsureTaskNonExistentForStatusUpdate() throws A2AServerException .isFinal(false) .build(); - Task task = taskManagerWithoutId.saveTaskEvent(event); + taskManagerWithoutId.saveTaskEvent(event, false); + Task task = taskManagerWithoutId.getTask(); assertEquals(event.taskId(), taskManagerWithoutId.getTaskId()); assertEquals(event.contextId(), taskManagerWithoutId.getContextId()); @@ -164,13 +165,13 @@ public void testSaveTaskEventNewTaskNoTaskId() throws A2AServerException { .status(new TaskStatus(TaskState.WORKING)) .build(); - Task saved = taskManagerWithoutId.saveTaskEvent(task); + taskManagerWithoutId.saveTaskEvent(task, false); + Task saved = taskManagerWithoutId.getTask(); assertEquals(task.id(), taskManagerWithoutId.getTaskId()); assertEquals(task.contextId(), taskManagerWithoutId.getContextId()); Task retrieved = taskManagerWithoutId.getTask(); assertSame(task, retrieved); - assertSame(retrieved, saved); } @Test @@ -194,7 +195,7 @@ public void testTaskArtifactUpdateEventAppendTrueWithExistingArtifact() throws A Task taskWithArtifact = Task.builder(initialTask) .artifacts(Collections.singletonList(existingArtifact)) .build(); - taskStore.save(taskWithArtifact); + taskStore.save(taskWithArtifact, false); // Test: Append new parts to existing artifact Artifact newArtifact = Artifact.builder() @@ -209,7 +210,8 @@ public void testTaskArtifactUpdateEventAppendTrueWithExistingArtifact() throws A .append(true) .build(); - Task updatedTask = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task updatedTask = taskManager.getTask(); assertEquals(1, updatedTask.artifacts().size()); Artifact updatedArtifact = updatedTask.artifacts().get(0);
@@ -223,7 +225,7 @@ public void testTaskArtifactUpdateEventAppendTrueWithExistingArtifact() throws A public void testTaskArtifactUpdateEventAppendTrueWithoutExistingArtifact() throws A2AServerException { // Setup: Create a task without artifacts Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); // Test: Try to append to non-existent artifact (should be ignored) Artifact newArtifact = Artifact.builder() @@ -238,7 +240,8 @@ public void testTaskArtifactUpdateEventAppendTrueWithoutExistingArtifact() throw .append(true) .build(); - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updatedTask = taskManager.getTask(); // Should have no artifacts since append was ignored @@ -257,7 +260,7 @@ public void testTaskArtifactUpdateEventAppendFalseWithExistingArtifact() throws Task taskWithArtifact = Task.builder(initialTask) .artifacts(Collections.singletonList(existingArtifact)) .build(); - taskStore.save(taskWithArtifact); + taskStore.save(taskWithArtifact, false); // Test: Replace existing artifact (append=false) Artifact newArtifact = Artifact.builder() @@ -272,7 +275,8 @@ public void testTaskArtifactUpdateEventAppendFalseWithExistingArtifact() throws .append(false) .build(); - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updatedTask = taskManager.getTask(); assertEquals(1, updatedTask.artifacts().size()); @@ -294,7 +298,7 @@ public void testTaskArtifactUpdateEventAppendNullWithExistingArtifact() throws A Task taskWithArtifact = Task.builder(initialTask) .artifacts(Collections.singletonList(existingArtifact)) .build(); - taskStore.save(taskWithArtifact); + taskStore.save(taskWithArtifact, false); // Test: Replace existing artifact (append=null, defaults to false) Artifact newArtifact = Artifact.builder() @@ -308,7 +312,8 @@ public void 
testTaskArtifactUpdateEventAppendNullWithExistingArtifact() throws A .artifact(newArtifact) .build(); // append is null - Task saved = taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); + Task saved = taskManager.getTask(); Task updatedTask = taskManager.getTask(); assertEquals(1, updatedTask.artifacts().size()); @@ -330,7 +335,7 @@ public void testAddingTaskWithDifferentIdFails() { .build(); assertThrows(A2AServerException.class, () -> { - taskManagerWithId.saveTaskEvent(differentTask); + taskManagerWithId.saveTaskEvent(differentTask, false); }); } @@ -347,7 +352,7 @@ public void testAddingTaskWithDifferentIdViaStatusUpdateFails() { .build(); assertThrows(A2AServerException.class, () -> { - taskManagerWithId.saveTaskEvent(event); + taskManagerWithId.saveTaskEvent(event, false); }); } @@ -368,7 +373,7 @@ public void testAddingTaskWithDifferentIdViaArtifactUpdateFails() { .build(); assertThrows(A2AServerException.class, () -> { - taskManagerWithId.saveTaskEvent(event); + taskManagerWithId.saveTaskEvent(event, false); }); } @@ -392,7 +397,8 @@ public void testTaskWithNoMessageUsesInitialMessage() throws A2AServerException .isFinal(false) .build(); - Task saved = taskManagerWithInitialMessage.saveTaskEvent(event); + taskManagerWithInitialMessage.saveTaskEvent(event, false); + Task saved = taskManagerWithInitialMessage.getTask(); Task retrieved = taskManagerWithInitialMessage.getTask(); // Check that the task has the initial message in its history @@ -429,7 +435,8 @@ public void testTaskWithMessageDoesNotUseInitialMessage() throws A2AServerExcept .isFinal(false) .build(); - Task saved = taskManagerWithInitialMessage.saveTaskEvent(event); + taskManagerWithInitialMessage.saveTaskEvent(event, false); + Task saved = taskManagerWithInitialMessage.getTask(); Task retrieved = taskManagerWithInitialMessage.getTask(); // There should now be a history containing the initialMessage @@ -447,7 +454,7 @@ public void testTaskWithMessageDoesNotUseInitialMessage() throws A2AServerExcept public
void testMultipleArtifactsWithSameArtifactId() throws A2AServerException { // Test handling of multiple artifacts with the same artifactId Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); // Add first artifact Artifact artifact1 = Artifact.builder() @@ -460,7 +467,7 @@ public void testMultipleArtifactsWithSameArtifactId() throws A2AServerException .contextId(minimalTask.contextId()) .artifact(artifact1) .build(); - taskManager.saveTaskEvent(event1); + taskManager.saveTaskEvent(event1, false); // Add second artifact with same artifactId (should replace the first) Artifact artifact2 = Artifact.builder() @@ -473,7 +480,7 @@ public void testMultipleArtifactsWithSameArtifactId() throws A2AServerException .contextId(minimalTask.contextId()) .artifact(artifact2) .build(); - taskManager.saveTaskEvent(event2); + taskManager.saveTaskEvent(event2, false); Task updatedTask = taskManager.getTask(); assertEquals(1, updatedTask.artifacts().size()); @@ -487,7 +494,7 @@ public void testMultipleArtifactsWithSameArtifactId() throws A2AServerException public void testMultipleArtifactsWithDifferentArtifactIds() throws A2AServerException { // Test handling of multiple artifacts with different artifactIds Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); // Add first artifact Artifact artifact1 = Artifact.builder() @@ -500,7 +507,7 @@ public void testMultipleArtifactsWithDifferentArtifactIds() throws A2AServerExce .contextId(minimalTask.contextId()) .artifact(artifact1) .build(); - taskManager.saveTaskEvent(event1); + taskManager.saveTaskEvent(event1, false); // Add second artifact with different artifactId (should be added) Artifact artifact2 = Artifact.builder() @@ -513,7 +520,7 @@ public void testMultipleArtifactsWithDifferentArtifactIds() throws A2AServerExce .contextId(minimalTask.contextId()) .artifact(artifact2) .build(); - taskManager.saveTaskEvent(event2); + 
taskManager.saveTaskEvent(event2, false); Task updatedTask = taskManager.getTask(); assertEquals(2, updatedTask.artifacts().size()); @@ -545,7 +552,7 @@ public void testInvalidTaskIdValidation() { public void testSaveTaskEventMetadataUpdate() throws A2AServerException { // Test that metadata from TaskStatusUpdateEvent gets saved to the task Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); Map newMetadata = new HashMap<>(); newMetadata.put("meta_key_test", "meta_value_test"); @@ -558,7 +565,7 @@ public void testSaveTaskEventMetadataUpdate() throws A2AServerException { .metadata(newMetadata) .build(); - taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); Task updatedTask = taskManager.getTask(); assertEquals(newMetadata, updatedTask.metadata()); @@ -568,7 +575,7 @@ public void testSaveTaskEventMetadataUpdate() throws A2AServerException { public void testSaveTaskEventMetadataUpdateNull() throws A2AServerException { // Test that null metadata in TaskStatusUpdateEvent doesn't affect task Task initialTask = minimalTask; - taskStore.save(initialTask); + taskStore.save(initialTask, false); TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId(minimalTask.id()) @@ -578,7 +585,7 @@ public void testSaveTaskEventMetadataUpdateNull() throws A2AServerException { .metadata(null) .build(); - taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); Task updatedTask = taskManager.getTask(); // Should preserve original task's metadata (which is likely null for minimal task) @@ -594,7 +601,7 @@ public void testSaveTaskEventMetadataMergeExisting() throws A2AServerException { Task taskWithMetadata = Task.builder(minimalTask) .metadata(originalMetadata) .build(); - taskStore.save(taskWithMetadata); + taskStore.save(taskWithMetadata, false); Map newMetadata = new HashMap<>(); newMetadata.put("new_key", "new_value"); @@ -607,7 +614,7 @@ public void 
testSaveTaskEventMetadataMergeExisting() throws A2AServerException { .metadata(newMetadata) .build(); - taskManager.saveTaskEvent(event); + taskManager.saveTaskEvent(event, false); Task updatedTask = taskManager.getTask(); @@ -634,7 +641,8 @@ public void testCreateTaskWithInitialMessage() throws A2AServerException { .isFinal(false) .build(); - Task savedTask = taskManagerWithMessage.saveTaskEvent(event); + taskManagerWithMessage.saveTaskEvent(event, false); + Task savedTask = taskManagerWithMessage.getTask(); // Verify task was created properly assertNotNull(savedTask); @@ -662,7 +670,8 @@ public void testCreateTaskWithoutInitialMessage() throws A2AServerException { .isFinal(false) .build(); - Task savedTask = taskManagerWithoutMessage.saveTaskEvent(event); + taskManagerWithoutMessage.saveTaskEvent(event, false); + Task savedTask = taskManagerWithoutMessage.getTask(); // Verify task was created properly assertNotNull(savedTask); @@ -685,7 +694,8 @@ public void testSaveTaskInternal() throws A2AServerException { .status(new TaskStatus(TaskState.WORKING)) .build(); - Task savedTask = taskManagerWithoutId.saveTaskEvent(newTask); + taskManagerWithoutId.saveTaskEvent(newTask, false); + Task savedTask = taskManagerWithoutId.getTask(); // Verify internal state was updated assertEquals("test-task-id", taskManagerWithoutId.getTaskId()); @@ -716,7 +726,8 @@ public void testUpdateWithMessage() throws A2AServerException { .isFinal(false) .build(); - Task saved = taskManagerWithInitialMessage.saveTaskEvent(event); + taskManagerWithInitialMessage.saveTaskEvent(event, false); + Task saved = taskManagerWithInitialMessage.getTask(); Message updateMessage = Message.builder() .role(Message.Role.USER) diff --git a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java index 40f763569..73da17824 100644 --- a/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java +++ 
b/server-common/src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java
@@ -14,7 +14,11 @@ import io.a2a.server.agentexecution.RequestContext;
 import io.a2a.server.events.EventQueue;
+import io.a2a.server.events.EventQueueItem;
 import io.a2a.server.events.EventQueueUtil;
+import io.a2a.server.events.InMemoryQueueManager;
+import io.a2a.server.events.MainEventBus;
+import io.a2a.server.events.MainEventBusProcessor;
 import io.a2a.spec.Event;
 import io.a2a.spec.Message;
 import io.a2a.spec.Part;
@@ -22,6 +26,7 @@ import io.a2a.spec.TaskState;
 import io.a2a.spec.TaskStatusUpdateEvent;
 import io.a2a.spec.TextPart;
+import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -38,14 +43,28 @@ public class TaskUpdaterTest {
 
     private static final List<Part<?>> SAMPLE_PARTS = List.of(new TextPart("Test message"));
+    private static final PushNotificationSender NOOP_PUSHNOTIFICATION_SENDER = task -> {};
+
     EventQueue eventQueue;
+    private MainEventBus mainEventBus;
+    private MainEventBusProcessor mainEventBusProcessor;
     private TaskUpdater taskUpdater;
 
     @BeforeEach
     public void init() {
-        eventQueue = EventQueueUtil.getEventQueueBuilder().build();
+        // Set up MainEventBus and processor for production-like test environment
+        InMemoryTaskStore taskStore = new InMemoryTaskStore();
+        mainEventBus = new MainEventBus();
+        InMemoryQueueManager queueManager = new InMemoryQueueManager(taskStore, mainEventBus);
+        mainEventBusProcessor = new MainEventBusProcessor(mainEventBus, taskStore, NOOP_PUSHNOTIFICATION_SENDER, queueManager);
+        EventQueueUtil.start(mainEventBusProcessor);
+
+        eventQueue = EventQueueUtil.getEventQueueBuilder(mainEventBus)
+                .taskId(TEST_TASK_ID)
+                .mainEventBus(mainEventBus)
+                .build().tap();
         RequestContext context = new RequestContext.Builder()
                 .setTaskId(TEST_TASK_ID)
                 .setContextId(TEST_TASK_CONTEXT_ID)
@@ -53,10 +72,19 @@ public void init() {
         taskUpdater = new TaskUpdater(context, eventQueue);
     }
 
+    @AfterEach
+    public
void cleanup() { + if (mainEventBusProcessor != null) { + EventQueueUtil.stop(mainEventBusProcessor); + } + } + @Test public void testAddArtifactWithCustomIdAndName() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "custom-artifact-id", "Custom Artifact", null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -239,7 +267,9 @@ public void testNewAgentMessageWithMetadata() throws Exception { @Test public void testAddArtifactWithAppendTrue() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, true, null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -258,7 +288,9 @@ public void testAddArtifactWithAppendTrue() throws Exception { @Test public void testAddArtifactWithLastChunkTrue() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, null, true); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -273,7 +305,9 @@ public void testAddArtifactWithLastChunkTrue() throws Exception { @Test public void testAddArtifactWithAppendAndLastChunk() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, "artifact-id", "Test Artifact", null, true, false); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); 
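The switch from `dequeueEventItem(0)` to a 5-second timeout in the hunks above exists because events no longer land on the ChildQueue synchronously: they pass through the MainEventBus and are distributed by a background processor. The persist-then-distribute ordering that processor guarantees can be sketched as follows; every name here is an illustrative stand-in, not the real `io.a2a.server.events` API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

// Illustrative stand-ins for MainEventBus / MainEventBusProcessor / TaskStore;
// the real classes live in io.a2a.server.events and io.a2a.server.tasks.
public class BusOrderingSketch {

    static String run() throws InterruptedException {
        BlockingQueue<String> bus = new LinkedBlockingQueue<>(); // MainEventBus analogue
        List<String> store = new ArrayList<>();                  // TaskStore analogue
        List<String> clients = new ArrayList<>();                // ChildQueue analogue

        // Single processor thread: persist FIRST, then distribute to children.
        Thread processor = new Thread(() -> {
            try {
                while (true) {
                    String event = bus.take();
                    if (event.equals("POISON")) {
                        return; // poison pill shuts the processor down
                    }
                    store.add(event);   // TaskStore.save() happens before...
                    clients.add(event); // ...distributeToChildren()
                }
            } catch (InterruptedException ignored) {
            }
        });
        processor.start();

        bus.put("working");
        bus.put("completed");
        bus.put("POISON");
        processor.join();

        // By the time a client saw an event, it was already persisted.
        return String.valueOf(store.equals(clients) && store.size() == 2);
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run()); // expected: true
    }
}
```

Because distribution happens on that separate thread, a test that polls the ChildQueue with a zero timeout races the processor; hence the bounded wait in the assertions above.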
assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -287,7 +321,9 @@ public void testAddArtifactWithAppendAndLastChunk() throws Exception { @Test public void testAddArtifactGeneratesIdWhenNull() throws Exception { taskUpdater.addArtifact(SAMPLE_PARTS, null, "Test Artifact", null); - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskArtifactUpdateEvent.class, event); @@ -383,7 +419,9 @@ public void testConcurrentCompletionAttempts() throws Exception { thread2.join(); // Exactly one event should have been queued - Event event = eventQueue.dequeueEventItem(0).getEvent(); + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskStatusUpdateEvent.class, event); @@ -396,7 +434,10 @@ public void testConcurrentCompletionAttempts() throws Exception { } private TaskStatusUpdateEvent checkTaskStatusUpdateEventOnQueue(boolean isFinal, TaskState state, Message statusMessage) throws Exception { - Event event = eventQueue.dequeueEventItem(0).getEvent(); + // Wait up to 5 seconds for event (async MainEventBusProcessor needs time to distribute) + EventQueueItem item = eventQueue.dequeueEventItem(5000); + assertNotNull(item); + Event event = item.getEvent(); assertNotNull(event); assertInstanceOf(TaskStatusUpdateEvent.class, event); @@ -408,6 +449,7 @@ private TaskStatusUpdateEvent checkTaskStatusUpdateEventOnQueue(boolean isFinal, assertEquals(state, tsue.status().state()); assertEquals(statusMessage, tsue.status().message()); + // Check no additional events (still use 0 timeout for this check) assertNull(eventQueue.dequeueEventItem(0)); return tsue; diff --git a/tck/src/main/resources/application.properties b/tck/src/main/resources/application.properties index c68793be4..b23747b00 100644 --- 
a/tck/src/main/resources/application.properties
+++ b/tck/src/main/resources/application.properties
@@ -12,6 +12,7 @@ a2a.executor.keep-alive-seconds=60
 quarkus.log.category."io.a2a.server.requesthandlers".level=DEBUG
 quarkus.log.category."io.a2a.server.events".level=DEBUG
 quarkus.log.category."io.a2a.server.tasks".level=DEBUG
+quarkus.log.category."io.a2a.server.diagnostics.ThreadStats".level=DEBUG
 
 # Log to file for analysis
 quarkus.log.file.enable=true
diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java
index 724b58613..62acc506e 100644
--- a/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java
+++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/AbstractA2AServerTest.java
@@ -656,7 +656,20 @@ public void testResubscribeExistingTaskSuccess() throws Exception {
         AtomicReference errorRef = new AtomicReference<>();
 
         // Create consumer to handle resubscribed events
+        AtomicBoolean receivedInitialTask = new AtomicBoolean(false);
         BiConsumer consumer = (event, agentCard) -> {
+            // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent
+            if (!receivedInitialTask.get()) {
+                if (event instanceof TaskEvent) {
+                    receivedInitialTask.set(true);
+                    // Don't count down latch for initial Task
+                    return;
+                } else {
+                    fail("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName());
+                }
+            }
+
+            // Process subsequent events
             if (event instanceof TaskUpdateEvent taskUpdateEvent) {
                 if (taskUpdateEvent.getUpdateEvent() instanceof TaskArtifactUpdateEvent artifactEvent) {
                     artifactUpdateEvent.set(artifactEvent);
@@ -755,12 +768,25 @@ public void testResubscribeExistingTaskSuccessWithClientConsumers() throws Excep
         AtomicReference errorRef = new AtomicReference<>();
 
         // Create consumer to handle resubscribed events
+        AtomicBoolean receivedInitialTask = new AtomicBoolean(false);
         AgentCard agentCard =
createTestAgentCard(); ClientConfig clientConfig = createClientConfig(true); ClientBuilder clientBuilder = Client .builder(agentCard) .addConsumer((evt, agentCard1) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!receivedInitialTask.get()) { + if (evt instanceof TaskEvent) { + receivedInitialTask.set(true); + // Don't count down latch for initial Task + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: " + evt.getClass().getSimpleName()); + } + } + + // Process subsequent events if (evt instanceof TaskUpdateEvent taskUpdateEvent) { if (taskUpdateEvent.getUpdateEvent() instanceof TaskArtifactUpdateEvent artifactEvent) { artifactUpdateEvent.set(artifactEvent); @@ -918,8 +944,20 @@ public void testMainQueueReferenceCountingWithMultipleConsumers() throws Excepti AtomicReference firstConsumerEvent = new AtomicReference<>(); AtomicBoolean firstUnexpectedEvent = new AtomicBoolean(false); AtomicReference firstErrorRef = new AtomicReference<>(); + AtomicBoolean firstReceivedInitialTask = new AtomicBoolean(false); BiConsumer firstConsumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!firstReceivedInitialTask.get()) { + if (event instanceof TaskEvent) { + firstReceivedInitialTask.set(true); + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent tue && tue.getUpdateEvent() instanceof TaskArtifactUpdateEvent artifact) { firstConsumerEvent.set(artifact); firstConsumerLatch.countDown(); @@ -975,8 +1013,20 @@ public void testMainQueueReferenceCountingWithMultipleConsumers() throws Excepti AtomicReference secondConsumerEvent = new AtomicReference<>(); AtomicBoolean secondUnexpectedEvent = new AtomicBoolean(false); AtomicReference secondErrorRef = new AtomicReference<>(); + AtomicBoolean secondReceivedInitialTask = new 
AtomicBoolean(false); BiConsumer secondConsumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!secondReceivedInitialTask.get()) { + if (event instanceof TaskEvent) { + secondReceivedInitialTask.set(true); + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent tue && tue.getUpdateEvent() instanceof TaskArtifactUpdateEvent artifact) { secondConsumerEvent.set(artifact); secondConsumerLatch.countDown(); @@ -1316,8 +1366,20 @@ public void testNonBlockingWithMultipleMessages() throws Exception { List resubReceivedEvents = new CopyOnWriteArrayList<>(); AtomicBoolean resubUnexpectedEvent = new AtomicBoolean(false); AtomicReference resubErrorRef = new AtomicReference<>(); + AtomicBoolean resubReceivedInitialTask = new AtomicBoolean(false); BiConsumer resubConsumer = (event, agentCard) -> { + // Per A2A spec 3.1.6: ENFORCE that first event is TaskEvent + if (!resubReceivedInitialTask.get()) { + if (event instanceof TaskEvent) { + resubReceivedInitialTask.set(true); + return; + } else { + fail("First event on resubscribe MUST be TaskEvent, but was: " + event.getClass().getSimpleName()); + } + } + + // Process subsequent events if (event instanceof TaskUpdateEvent tue) { resubReceivedEvents.add(tue.getUpdateEvent()); resubEventLatch.countDown(); @@ -1355,6 +1417,7 @@ public void testNonBlockingWithMultipleMessages() throws Exception { AtomicBoolean streamUnexpectedEvent = new AtomicBoolean(false); BiConsumer streamConsumer = (event, agentCard) -> { + // This consumer is for sendMessage() (not resubscribe), so it doesn't get initial TaskEvent if (event instanceof TaskUpdateEvent tue) { streamReceivedEvents.add(tue.getUpdateEvent()); streamEventLatch.countDown(); diff --git a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java 
b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java index 9df23c565..45483f214 100644 --- a/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java +++ b/tests/server-common/src/test/java/io/a2a/server/apps/common/TestUtilsBean.java @@ -31,7 +31,7 @@ public class TestUtilsBean { PushNotificationConfigStore pushNotificationConfigStore; public void saveTask(Task task) { - taskStore.save(task); + taskStore.save(task, false); } public Task getTask(String taskId) { diff --git a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java index 408205aa2..439d97497 100644 --- a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java +++ b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java @@ -242,7 +242,7 @@ public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request, A2AExtensions.validateRequiredExtensions(getAgentCardInternal(), context); MessageSendParams params = FromProto.messageSendParams(request); Flow.Publisher publisher = getRequestHandler().onMessageSendStream(params, context); - convertToStreamResponse(publisher, responseObserver); + convertToStreamResponse(publisher, responseObserver, context); } catch (A2AError e) { handleError(responseObserver, e); } catch (SecurityException e) { @@ -264,7 +264,7 @@ public void subscribeToTask(io.a2a.grpc.SubscribeToTaskRequest request, ServerCallContext context = createCallContext(responseObserver); TaskIdParams params = FromProto.taskIdParams(request); Flow.Publisher publisher = getRequestHandler().onResubscribeToTask(params, context); - convertToStreamResponse(publisher, responseObserver); + convertToStreamResponse(publisher, responseObserver, context); } catch (A2AError e) { handleError(responseObserver, e); } catch (SecurityException e) { @@ -275,7 +275,8 @@ public void subscribeToTask(io.a2a.grpc.SubscribeToTaskRequest 
request, } private void convertToStreamResponse(Flow.Publisher publisher, - StreamObserver responseObserver) { + StreamObserver responseObserver, + ServerCallContext context) { CompletableFuture.runAsync(() -> { publisher.subscribe(new Flow.Subscriber() { private Flow.Subscription subscription; @@ -285,6 +286,18 @@ public void onSubscribe(Flow.Subscription subscription) { this.subscription = subscription; subscription.request(1); + // Detect gRPC client disconnect and call EventConsumer.cancel() directly + // This stops the polling loop without relying on subscription cancellation propagation + Context grpcContext = Context.current(); + grpcContext.addListener(new Context.CancellationListener() { + @Override + public void cancelled(Context ctx) { + LOGGER.fine(() -> "gRPC call cancelled by client, calling EventConsumer.cancel() to stop polling loop"); + context.invokeEventConsumerCancelCallback(); + subscription.cancel(); + } + }, getExecutor()); + // Notify tests that we are subscribed Runnable runnable = streamingSubscribedRunnable; if (runnable != null) { @@ -305,6 +318,8 @@ public void onNext(StreamingEventKind event) { @Override public void onError(Throwable throwable) { + // Cancel upstream to stop EventConsumer when error occurs + subscription.cancel(); if (throwable instanceof A2AError jsonrpcError) { handleError(responseObserver, jsonrpcError); } else { @@ -329,6 +344,9 @@ public void getExtendedAgentCard(io.a2a.grpc.GetExtendedAgentCardRequest request if (extendedAgentCard != null) { responseObserver.onNext(ToProto.agentCard(extendedAgentCard)); responseObserver.onCompleted(); + } else { + // Extended agent card not configured - return error instead of hanging + handleError(responseObserver, new ExtendedAgentCardNotConfiguredError(null, "Extended agent card not configured", null)); } } catch (Throwable t) { handleInternalError(responseObserver, t); diff --git a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java 
b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java index 690d69a87..afed1329f 100644 --- a/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java +++ b/transport/grpc/src/test/java/io/a2a/transport/grpc/handler/GrpcHandlerTest.java @@ -84,7 +84,7 @@ public class GrpcHandlerTest extends AbstractA2ARequestHandlerTest { @Test public void testOnGetTaskSuccess() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); GetTaskRequest request = GetTaskRequest.newBuilder() .setName("tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()) .build(); @@ -120,7 +120,7 @@ public void testOnGetTaskNotFound() throws Exception { @Test public void testOnCancelTaskSuccess() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorCancel = (context, eventQueue) -> { // We need to cancel the task or the EventConsumer never finds a 'final' event. 
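Several of the resubscribe tests above depend on the two-level lifecycle protection from the commit message: a MainQueue must survive its last ChildQueue disconnecting unless the task is finalized. A rough sketch of that check, with hypothetical names (the real implementation consults `TaskStateProvider.isTaskFinalized()` inside the QueueManager cleanup callback):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;

// Hypothetical sketch of the two-level queue-lifecycle protection; names are
// stand-ins for the QueueManager map and TaskStateProvider.isTaskFinalized().
public class QueueCleanupSketch {

    static final Map<String, Object> queues = new ConcurrentHashMap<>();

    // Cleanup callback: only drop the queue once the task has reached a final state.
    static void onLastChildClosed(String taskId, Predicate<String> isTaskFinalized) {
        if (isTaskFinalized.test(taskId)) {
            queues.remove(taskId); // finalized: clean up immediately
        }
        // non-final: keep the MainQueue open so a late resubscribe can tap it
    }

    static String run() {
        queues.put("task-running", new Object());
        queues.put("task-done", new Object());
        Predicate<String> finalized = "task-done"::equals;

        onLastChildClosed("task-running", finalized);
        onLastChildClosed("task-done", finalized);

        return queues.containsKey("task-running") + " " + queues.containsKey("task-done");
    }

    public static void main(String[] args) {
        System.out.println(run()); // expected: true false
    }
}
```

This is why `queueManager.createOrTap(...)` followed by a resubscribe succeeds in the tests: the running task's queue is never reaped out from under the late subscriber.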
@@ -151,7 +151,7 @@ public void testOnCancelTaskSuccess() throws Exception { @Test public void testOnCancelTaskNotSupported() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorCancel = (context, eventQueue) -> { throw new UnsupportedOperationError(); @@ -199,7 +199,7 @@ public void testOnMessageNewMessageSuccess() throws Exception { @Test public void testOnMessageNewMessageWithExistingTaskSuccess() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); agentExecutorExecute = (context, eventQueue) -> { eventQueue.enqueueEvent(context.getMessage()); }; @@ -281,8 +281,7 @@ public void testPushNotificationsNotSupportedError() throws Exception { @Test public void testOnGetPushNotificationNoPushNotifierConfig() throws Exception { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = - new DefaultRequestHandler(executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(false, true, false); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); @@ -293,8 +292,7 @@ public void testOnGetPushNotificationNoPushNotifierConfig() throws Exception { @Test public void 
testOnSetPushNotificationNoPushNotifierConfig() throws Exception { // Create request handler without a push notifier - DefaultRequestHandler requestHandler = DefaultRequestHandler.create( - executor, taskStore, queueManager, null, null, internalExecutor); + DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor); AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(false, true, false); GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor); String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(); @@ -330,7 +328,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccess() throws Exception io.a2a.spec.Task task = io.a2a.spec.Task.builder(AbstractA2ARequestHandlerTest.MINIMAL_TASK) .history(new ArrayList<>()) .build(); - taskStore.save(task); + taskStore.save(task, false); List results = new ArrayList<>(); List errors = new ArrayList<>(); @@ -379,7 +377,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() throws Excep io.a2a.spec.Task task = io.a2a.spec.Task.builder(AbstractA2ARequestHandlerTest.MINIMAL_TASK) .history(new ArrayList<>()) .build(); - taskStore.save(task); + taskStore.save(task, false); // This is used to send events from a mock List events = List.of( @@ -424,9 +422,14 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() throws Excep @Test public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception { - GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - List events = List.of( - AbstractA2ARequestHandlerTest.MINIMAL_TASK, + // Use synchronous executor for push notifications to ensure deterministic ordering + // Without this, async push notifications can execute out of order, 
causing test flakiness + mainEventBusProcessor.setPushNotificationExecutor(Runnable::run); + + try { + GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); + List events = List.of( + AbstractA2ARequestHandlerTest.MINIMAL_TASK, TaskArtifactUpdateEvent.builder() .taskId(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()) .contextId(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId()) @@ -493,13 +496,16 @@ public void onCompleted() { Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); - curr = httpClient.tasks.get(2); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); - Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); - Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); - Assertions.assertEquals(1, curr.artifacts().size()); - Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); - Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + curr = httpClient.tasks.get(2); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id(), curr.id()); + Assertions.assertEquals(AbstractA2ARequestHandlerTest.MINIMAL_TASK.contextId(), curr.contextId()); + Assertions.assertEquals(io.a2a.spec.TaskState.COMPLETED, curr.status().state()); + Assertions.assertEquals(1, curr.artifacts().size()); + Assertions.assertEquals(1, curr.artifacts().get(0).parts().size()); + Assertions.assertEquals("text", ((TextPart)curr.artifacts().get(0).parts().get(0)).text()); + } finally { + mainEventBusProcessor.setPushNotificationExecutor(null); + } } @Test @@ -517,7 +523,7 @@ public void testOnResubscribeNoExistingTaskError() throws Exception { @Test public void testOnResubscribeExistingTaskSuccess() throws Exception { GrpcHandler handler = new 
TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); queueManager.createOrTap(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()); agentExecutorExecute = (context, eventQueue) -> { @@ -542,9 +548,18 @@ public void testOnResubscribeExistingTaskSuccess() throws Exception { streamRecorder.awaitCompletion(5, TimeUnit.SECONDS); List result = streamRecorder.getValues(); Assertions.assertNotNull(result); - Assertions.assertEquals(1, result.size()); - StreamResponse response = result.get(0); - Assertions.assertTrue(response.hasMessage()); + // Per A2A Protocol Spec 3.1.6, resubscribe sends current Task as first event, + // followed by the Message from the agent executor + Assertions.assertEquals(2, result.size()); + + // ENFORCE that first event is Task + Assertions.assertTrue(result.get(0).hasTask(), + "First event on resubscribe MUST be Task (current state)"); + + // Second event should be Message from agent executor + StreamResponse response = result.get(1); + Assertions.assertTrue(response.hasMessage(), + "Expected Message after initial Task"); assertEquals(GRPC_MESSAGE, response.getMessage()); Assertions.assertNull(streamRecorder.getError()); } @@ -552,7 +567,7 @@ public void testOnResubscribeExistingTaskSuccess() throws Exception { @Test public void testOnResubscribeExistingTaskSuccessMocks() throws Exception { GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor); - taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK); + taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false); queueManager.createOrTap(AbstractA2ARequestHandlerTest.MINIMAL_TASK.id()); List events = List.of( @@ -627,7 +642,7 @@ public void testOnMessageStreamInternalError() throws Exception { @Test public void testListPushNotificationConfig() throws Exception { 
         GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor);
-        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK);
+        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -653,7 +668,7 @@ public void testListPushNotificationConfig() throws Exception {
     public void testListPushNotificationConfigNotSupported() throws Exception {
         AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(true, false, true);
         GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor);
-        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK);
+        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -668,10 +683,9 @@ public void testListPushNotificationConfigNotSupported() throws Exception {

     @Test
     public void testListPushNotificationConfigNoPushConfigStore() {
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor);
-        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK);
+        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -702,7 +716,7 @@ public void testListPushNotificationConfigTaskNotFound() {

     @Test
     public void testDeletePushNotificationConfig() throws Exception {
         GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor);
-        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK);
+        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -725,7 +739,7 @@ public void testDeletePushNotificationConfig() throws Exception {
     public void testDeletePushNotificationConfigNotSupported() throws Exception {
         AgentCard card = AbstractA2ARequestHandlerTest.createAgentCard(true, false, true);
         GrpcHandler handler = new TestGrpcHandler(card, requestHandler, internalExecutor);
-        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK);
+        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -741,8 +755,7 @@ public void testDeletePushNotificationConfigNotSupported() throws Exception {

     @Test
     public void testDeletePushNotificationConfigNoPushConfigStore() {
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         GrpcHandler handler = new TestGrpcHandler(AbstractA2ARequestHandlerTest.CARD, requestHandler, internalExecutor);
         String NAME = "tasks/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id() + "/pushNotificationConfigs/" + AbstractA2ARequestHandlerTest.MINIMAL_TASK.id();
         DeleteTaskPushNotificationConfigRequest request = DeleteTaskPushNotificationConfigRequest.newBuilder()
@@ -1155,7 +1168,7 @@ private StreamRecorder sendMessageRequest(GrpcHandler handler) {
     }

     private StreamRecorder createTaskPushNotificationConfigRequest(GrpcHandler handler, String name) throws Exception {
-        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK);
+        taskStore.save(AbstractA2ARequestHandlerTest.MINIMAL_TASK, false);
         PushNotificationConfig config = PushNotificationConfig.newBuilder()
                 .setUrl("http://example.com")
                 .setId("config456")
diff --git a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java
index b43c28029..4dc151626 100644
--- a/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java
+++ b/transport/jsonrpc/src/test/java/io/a2a/transport/jsonrpc/handler/JSONRPCHandlerTest.java
@@ -3,6 +3,8 @@
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertInstanceOf;
 import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;

 import java.util.ArrayList;
 import java.util.Collections;
@@ -30,6 +32,8 @@
 import io.a2a.jsonrpc.common.wrappers.GetTaskResponse;
 import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigRequest;
 import io.a2a.jsonrpc.common.wrappers.ListTaskPushNotificationConfigResponse;
+import io.a2a.jsonrpc.common.wrappers.ListTasksRequest;
+import io.a2a.jsonrpc.common.wrappers.ListTasksResponse;
 import io.a2a.jsonrpc.common.wrappers.ListTasksResult;
 import io.a2a.jsonrpc.common.wrappers.SendMessageRequest;
 import io.a2a.jsonrpc.common.wrappers.SendMessageResponse;
@@ -37,12 +41,11 @@
 import io.a2a.jsonrpc.common.wrappers.SendStreamingMessageResponse;
 import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigRequest;
 import io.a2a.jsonrpc.common.wrappers.SetTaskPushNotificationConfigResponse;
-import io.a2a.jsonrpc.common.wrappers.ListTasksRequest;
-import io.a2a.jsonrpc.common.wrappers.ListTasksResponse;
 import io.a2a.jsonrpc.common.wrappers.SubscribeToTaskRequest;
 import io.a2a.server.ServerCallContext;
 import io.a2a.server.auth.UnauthenticatedUser;
 import io.a2a.server.events.EventConsumer;
+import io.a2a.server.events.MainEventBusProcessorCallback;
 import io.a2a.server.requesthandlers.AbstractA2ARequestHandlerTest;
 import io.a2a.server.requesthandlers.DefaultRequestHandler;
 import io.a2a.server.tasks.ResultAggregator;
@@ -52,16 +55,15 @@
 import io.a2a.spec.AgentExtension;
 import io.a2a.spec.AgentInterface;
 import io.a2a.spec.Artifact;
-import io.a2a.spec.ExtendedAgentCardNotConfiguredError;
-import io.a2a.spec.ExtensionSupportRequiredError;
-import io.a2a.spec.VersionNotSupportedError;
 import io.a2a.spec.DeleteTaskPushNotificationConfigParams;
 import io.a2a.spec.Event;
+import io.a2a.spec.ExtendedAgentCardNotConfiguredError;
+import io.a2a.spec.ExtensionSupportRequiredError;
 import io.a2a.spec.GetTaskPushNotificationConfigParams;
 import io.a2a.spec.InternalError;
 import io.a2a.spec.InvalidRequestError;
-import io.a2a.spec.ListTasksParams;
 import io.a2a.spec.ListTaskPushNotificationConfigParams;
+import io.a2a.spec.ListTasksParams;
 import io.a2a.spec.Message;
 import io.a2a.spec.MessageSendParams;
 import io.a2a.spec.PushNotificationConfig;
@@ -78,6 +80,7 @@
 import io.a2a.spec.TaskStatusUpdateEvent;
 import io.a2a.spec.TextPart;
 import io.a2a.spec.UnsupportedOperationError;
+import io.a2a.spec.VersionNotSupportedError;
 import mutiny.zero.ZeroPublisher;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Disabled;
@@ -92,7 +95,7 @@ public class JSONRPCHandlerTest extends AbstractA2ARequestHandlerTest {

     @Test
     public void testOnGetTaskSuccess() throws Exception {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.id()));
         GetTaskResponse response = handler.onGetTask(request, callContext);
         assertEquals(request.getId(), response.getId());
@@ -113,7 +116,7 @@ public void testOnGetTaskNotFound() throws Exception {

     @Test
     public void testOnCancelTaskSuccess() throws Exception {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorCancel = (context, eventQueue) -> {
             // We need to cancel the task or the EventConsumer never finds a 'final' event.
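The hunks that follow replace immediate assertions with polling: once the MainEventBus persists events asynchronously on its own processor thread, a test can no longer assume the TaskStore is up to date the moment the stream completes. The sketch below illustrates that polling pattern in isolation; the class and method names (`AwaitSketch`, `awaitCondition`) and the `Map`-backed stand-in for the TaskStore are hypothetical, not part of the SDK.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import java.util.function.Supplier;

// Hypothetical helper: poll a source until a predicate holds or a timeout expires,
// mirroring how the updated tests wait for the bus processor to persist the final state.
class AwaitSketch {
    static <T> boolean awaitCondition(Supplier<T> source, Predicate<T> done, long timeoutMillis)
            throws InterruptedException {
        long deadline = System.currentTimeMillis() + timeoutMillis;
        while (System.currentTimeMillis() < deadline) {
            T value = source.get();
            if (value != null && done.test(value)) {
                return true;
            }
            Thread.sleep(10); // back off between polls
        }
        return false;
    }

    public static void main(String[] args) throws Exception {
        Map<String, String> store = new ConcurrentHashMap<>(); // stands in for the TaskStore
        new Thread(() -> {
            try { Thread.sleep(50); } catch (InterruptedException ignored) { }
            store.put("task-1", "COMPLETED"); // async persist, as the bus processor would do
        }).start();
        boolean finalized = awaitCondition(() -> store.get("task-1"), "COMPLETED"::equals, 2000);
        if (!finalized) throw new AssertionError("task never finalized");
        System.out.println("finalized");
    }
}
```

Bounding the poll with a deadline (rather than a fixed sleep) keeps the tests fast on the happy path while still tolerating slow CI machines, which is the same trade-off the diff makes with its `for (int i = 0; i < 50; i++)` loop.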
@@ -138,7 +141,7 @@ public void testOnCancelTaskSuccess() throws Exception {

     @Test
     public void testOnCancelTaskNotSupported() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorCancel = (context, eventQueue) -> {
             throw new UnsupportedOperationError();
@@ -174,42 +177,13 @@ public void testOnMessageNewMessageSuccess() {
         SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null));
         SendMessageResponse response = handler.onMessageSend(request, callContext);
         assertNull(response.getError());
-        // The Python implementation returns a Task here, but then again they are using hardcoded mocks and
-        // bypassing the whole EventQueue.
-        // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to
-        // the Task not having a 'final' state
-        //
-        // See testOnMessageNewMessageSuccessMocks() for a test more similar to the Python implementation
         Assertions.assertSame(message, response.getResult());
     }

-    @Test
-    public void testOnMessageNewMessageSuccessMocks() {
-        JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-
-        Message message = Message.builder(MESSAGE)
-                .taskId(MINIMAL_TASK.id())
-                .contextId(MINIMAL_TASK.contextId())
-                .build();
-
-        SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null));
-        SendMessageResponse response;
-        try (MockedConstruction mocked = Mockito.mockConstruction(
-                EventConsumer.class,
-                (mock, context) -> {
-                    Mockito.doReturn(ZeroPublisher.fromItems(wrapEvent(MINIMAL_TASK))).when(mock).consumeAll();
-                    Mockito.doCallRealMethod().when(mock).createAgentRunnableDoneCallback();
-                })) {
-            response = handler.onMessageSend(request, callContext);
-        }
-        assertNull(response.getError());
-        Assertions.assertSame(MINIMAL_TASK, response.getResult());
-    }
-
     @Test
     public void testOnMessageNewMessageWithExistingTaskSuccess() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getMessage());
         };
@@ -220,38 +194,9 @@ public void testOnMessageNewMessageWithExistingTaskSuccess() {
         SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null));
         SendMessageResponse response = handler.onMessageSend(request, callContext);
         assertNull(response.getError());
-        // The Python implementation returns a Task here, but then again they are using hardcoded mocks and
-        // bypassing the whole EventQueue.
-        // If we were to send a Task in agentExecutorExecute EventConsumer.consumeAll() would not exit due to
-        // the Task not having a 'final' state
-        //
-        // See testOnMessageNewMessageWithExistingTaskSuccessMocks() for a test more similar to the Python implementation
         Assertions.assertSame(message, response.getResult());
     }

-    @Test
-    public void testOnMessageNewMessageWithExistingTaskSuccessMocks() {
-        JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
-
-        Message message = Message.builder(MESSAGE)
-                .taskId(MINIMAL_TASK.id())
-                .contextId(MINIMAL_TASK.contextId())
-                .build();
-        SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null));
-        SendMessageResponse response;
-        try (MockedConstruction mocked = Mockito.mockConstruction(
-                EventConsumer.class,
-                (mock, context) -> {
-                    Mockito.doReturn(ZeroPublisher.fromItems(wrapEvent(MINIMAL_TASK))).when(mock).consumeAll();
-                })) {
-            response = handler.onMessageSend(request, callContext);
-        }
-        assertNull(response.getError());
-        Assertions.assertSame(MINIMAL_TASK, response.getResult());
-
-    }
-
     @Test
     public void testOnMessageError() {
         // See testMessageOnErrorMocks() for a test more similar to the Python implementation, using mocks for
@@ -352,9 +297,11 @@ public void onComplete() {

     @Test
     public void testOnMessageStreamNewMessageMultipleEventsSuccess() throws InterruptedException {
+        // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback
+        // We'll verify persistence by checking TaskStore after streaming completes
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);

-        // Create multiple events to be sent during streaming
+        // Create multiple events to be sent during streaming
         Task taskEvent = Task.builder(MINIMAL_TASK)
                 .status(new TaskStatus(TaskState.WORKING))
                 .build();
@@ -429,8 +376,8 @@ public void onComplete() {
             }
         });

-        // Wait for all events to be received
-        Assertions.assertTrue(latch.await(2, TimeUnit.SECONDS),
+        // Wait for all events to be received (increased timeout for async processing)
+        assertTrue(latch.await(10, TimeUnit.SECONDS),
                 "Expected to receive 3 events within timeout");

         // Assert no error occurred during streaming
@@ -456,6 +403,17 @@ public void onComplete() {
                 "Third event should be a TaskStatusUpdateEvent");
         assertEquals(MINIMAL_TASK.id(), receivedStatus.taskId());
         assertEquals(TaskState.COMPLETED, receivedStatus.status().state());
+
+        // Verify events were persisted to TaskStore (poll for final state)
+        for (int i = 0; i < 50; i++) {
+            Task storedTask = taskStore.get(MINIMAL_TASK.id());
+            if (storedTask != null && storedTask.status() != null
+                    && TaskState.COMPLETED.equals(storedTask.status().state())) {
+                return; // Success - task finalized in TaskStore
+            }
+            Thread.sleep(100);
+        }
+        fail("Task should have been finalized in TaskStore within timeout");
     }

@@ -538,7 +496,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccess() throws Exception
         Task task = Task.builder(MINIMAL_TASK)
                 .history(new ArrayList<>())
                 .build();
-        taskStore.save(task);
+        taskStore.save(task, false);

         Message message = Message.builder(MESSAGE)
                 .taskId(task.id())
@@ -583,7 +541,7 @@ public void onComplete() {
             });
         });

-        Assertions.assertTrue(latch.await(1, TimeUnit.SECONDS));
+        assertTrue(latch.await(1, TimeUnit.SECONDS));
         subscriptionRef.get().cancel();
         // The Python implementation has several events emitted since it uses mocks.
         //
@@ -608,7 +566,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() {
         Task task = Task.builder(MINIMAL_TASK)
                 .history(new ArrayList<>())
                 .build();
-        taskStore.save(task);
+        taskStore.save(task, false);

         // This is used to send events from a mock
         List events = List.of(
@@ -682,7 +640,7 @@ public void onComplete() {

     @Test
     public void testSetPushNotificationConfigSuccess() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);

         TaskPushNotificationConfig taskPushConfig =
                 new TaskPushNotificationConfig(
@@ -704,7 +662,7 @@ public void testSetPushNotificationConfigSuccess() {

     @Test
     public void testGetPushNotificationConfigSuccess() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ?
                     context.getTask() : context.getMessage());
         };
@@ -729,112 +687,124 @@ public void testGetPushNotificationConfigSuccess() {

     @Test
     public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Exception {
-        JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
-
-        List events = List.of(
-                MINIMAL_TASK,
-                TaskArtifactUpdateEvent.builder()
-                        .taskId(MINIMAL_TASK.id())
-                        .contextId(MINIMAL_TASK.contextId())
-                        .artifact(Artifact.builder()
-                                .artifactId("11")
-                                .parts(new TextPart("text"))
-                                .build())
-                        .build(),
-                TaskStatusUpdateEvent.builder()
-                        .taskId(MINIMAL_TASK.id())
-                        .contextId(MINIMAL_TASK.contextId())
-                        .status(new TaskStatus(TaskState.COMPLETED))
-                        .build());
-
-        agentExecutorExecute = (context, eventQueue) -> {
-            // Hardcode the events to send here
-            for (Event event : events) {
-                eventQueue.enqueueEvent(event);
-            }
-        };
-
-        TaskPushNotificationConfig config = new TaskPushNotificationConfig(
-                MINIMAL_TASK.id(),
-                PushNotificationConfig.builder().id("c295ea44-7543-4f78-b524-7a38915ad6e4").url("http://example.com").build(), "tenant");
-
-        SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config);
-        SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext);
-        assertNull(stpnResponse.getError());
-
-        Message msg = Message.builder(MESSAGE)
-                .taskId(MINIMAL_TASK.id())
-                .build();
-        SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null));
-        Flow.Publisher response = handler.onMessageSendStream(request, callContext);
-
-        final List results = Collections.synchronizedList(new ArrayList<>());
-        final AtomicReference subscriptionRef = new AtomicReference<>();
-        final CountDownLatch latch = new CountDownLatch(6);
-        httpClient.latch = latch;
-
-        Executors.newSingleThreadExecutor().execute(() -> {
-            response.subscribe(new Flow.Subscriber<>() {
-                @Override
-                public void onSubscribe(Flow.Subscription subscription) {
-                    subscriptionRef.set(subscription);
-                    subscription.request(1);
-                }
-
-                @Override
-                public void onNext(SendStreamingMessageResponse item) {
-                    System.out.println("-> " + item.getResult());
-                    results.add(item.getResult());
-                    System.out.println(results);
-                    subscriptionRef.get().request(1);
-                    latch.countDown();
-                }
-
-                @Override
-                public void onError(Throwable throwable) {
-                    subscriptionRef.get().cancel();
-                }
-
-                @Override
-                public void onComplete() {
-                    subscriptionRef.get().cancel();
+        // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback
+
+        // Use synchronous executor for push notifications to ensure deterministic ordering
+        // Without this, async push notifications can execute out of order, causing test flakiness
+        mainEventBusProcessor.setPushNotificationExecutor(Runnable::run);
+
+        try {
+            JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
+            taskStore.save(MINIMAL_TASK, false);
+
+            List events = List.of(
+                    MINIMAL_TASK,
+                    TaskArtifactUpdateEvent.builder()
+                            .taskId(MINIMAL_TASK.id())
+                            .contextId(MINIMAL_TASK.contextId())
+                            .artifact(Artifact.builder()
+                                    .artifactId("11")
+                                    .parts(new TextPart("text"))
+                                    .build())
+                            .build(),
+                    TaskStatusUpdateEvent.builder()
+                            .taskId(MINIMAL_TASK.id())
+                            .contextId(MINIMAL_TASK.contextId())
+                            .status(new TaskStatus(TaskState.COMPLETED))
+                            .build());
+
+
+            agentExecutorExecute = (context, eventQueue) -> {
+                // Hardcode the events to send here
+                for (Event event : events) {
+                    eventQueue.enqueueEvent(event);
                 }
+            };
+
+            TaskPushNotificationConfig config = new TaskPushNotificationConfig(
+                    MINIMAL_TASK.id(),
+                    PushNotificationConfig.builder().id("c295ea44-7543-4f78-b524-7a38915ad6e4").url("http://example.com").build(), "tenant");
+
+            SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config);
+            SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext);
+            assertNull(stpnResponse.getError());
+
+            Message msg = Message.builder(MESSAGE)
+                    .taskId(MINIMAL_TASK.id())
+                    .build();
+            SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null));
+            Flow.Publisher response = handler.onMessageSendStream(request, callContext);
+
+            final List results = Collections.synchronizedList(new ArrayList<>());
+            final AtomicReference subscriptionRef = new AtomicReference<>();
+            final CountDownLatch latch = new CountDownLatch(6);
+            httpClient.latch = latch;
+
+            Executors.newSingleThreadExecutor().execute(() -> {
+                response.subscribe(new Flow.Subscriber<>() {
+                    @Override
+                    public void onSubscribe(Flow.Subscription subscription) {
+                        subscriptionRef.set(subscription);
+                        subscription.request(1);
+                    }
+
+                    @Override
+                    public void onNext(SendStreamingMessageResponse item) {
+                        System.out.println("-> " + item.getResult());
+                        results.add(item.getResult());
+                        System.out.println(results);
+                        subscriptionRef.get().request(1);
+                        latch.countDown();
+                    }
+
+                    @Override
+                    public void onError(Throwable throwable) {
+                        subscriptionRef.get().cancel();
+                    }
+
+                    @Override
+                    public void onComplete() {
+                        subscriptionRef.get().cancel();
+                    }
+                });
             });
-        });

-        Assertions.assertTrue(latch.await(5, TimeUnit.SECONDS));
-        subscriptionRef.get().cancel();
-        assertEquals(3, results.size());
-        assertEquals(3, httpClient.tasks.size());
-
-        Task curr = httpClient.tasks.get(0);
-        assertEquals(MINIMAL_TASK.id(), curr.id());
-        assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
-        assertEquals(MINIMAL_TASK.status().state(), curr.status().state());
-        assertEquals(0, curr.artifacts() == null ? 0 : curr.artifacts().size());
-
-        curr = httpClient.tasks.get(1);
-        assertEquals(MINIMAL_TASK.id(), curr.id());
-        assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
-        assertEquals(MINIMAL_TASK.status().state(), curr.status().state());
-        assertEquals(1, curr.artifacts().size());
-        assertEquals(1, curr.artifacts().get(0).parts().size());
-        assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text());
-
-        curr = httpClient.tasks.get(2);
-        assertEquals(MINIMAL_TASK.id(), curr.id());
-        assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
-        assertEquals(TaskState.COMPLETED, curr.status().state());
-        assertEquals(1, curr.artifacts().size());
-        assertEquals(1, curr.artifacts().get(0).parts().size());
-        assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text());
+            assertTrue(latch.await(5, TimeUnit.SECONDS));
+
+            subscriptionRef.get().cancel();
+            assertEquals(3, results.size());
+            assertEquals(3, httpClient.tasks.size());
+
+            Task curr = httpClient.tasks.get(0);
+            assertEquals(MINIMAL_TASK.id(), curr.id());
+            assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
+            assertEquals(MINIMAL_TASK.status().state(), curr.status().state());
+            assertEquals(0, curr.artifacts() == null ? 0 : curr.artifacts().size());
+
+            curr = httpClient.tasks.get(1);
+            assertEquals(MINIMAL_TASK.id(), curr.id());
+            assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
+            assertEquals(MINIMAL_TASK.status().state(), curr.status().state());
+            assertEquals(1, curr.artifacts().size());
+            assertEquals(1, curr.artifacts().get(0).parts().size());
+            assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text());
+
+            curr = httpClient.tasks.get(2);
+            assertEquals(MINIMAL_TASK.id(), curr.id());
+            assertEquals(MINIMAL_TASK.contextId(), curr.contextId());
+            assertEquals(TaskState.COMPLETED, curr.status().state());
+            assertEquals(1, curr.artifacts().size());
+            assertEquals(1, curr.artifacts().get(0).parts().size());
+            assertEquals("text", ((TextPart) curr.artifacts().get(0).parts().get(0)).text());
+        } finally {
+            mainEventBusProcessor.setPushNotificationExecutor(null);
+        }
     }

     @Test
     public void testOnResubscribeExistingTaskSuccess() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         queueManager.createOrTap(MINIMAL_TASK.id());

         agentExecutorExecute = (context, eventQueue) -> {
@@ -861,6 +831,7 @@ public void testOnResubscribeExistingTaskSuccess() {

         CompletableFuture future = new CompletableFuture<>();
         List results = new ArrayList<>();
+        AtomicBoolean receivedInitialTask = new AtomicBoolean(false);
         response.subscribe(new Flow.Subscriber<>() {
             private Flow.Subscription subscription;
@@ -873,7 +844,20 @@ public void onSubscribe(Flow.Subscription subscription) {

             @Override
             public void onNext(SendStreamingMessageResponse item) {
-                results.add(item.getResult());
+                StreamingEventKind event = item.getResult();
+                results.add(event);
+
+                // Per A2A Protocol Spec 3.1.6: ENFORCE that first event is Task
+                if (!receivedInitialTask.get()) {
+                    assertTrue(event instanceof Task,
+                            "First event on resubscribe MUST be Task (current state), but was: " +
                                    event.getClass().getSimpleName());
+                    receivedInitialTask.set(true);
+                } else {
+                    // Subsequent events should be the expected type (Message in this case)
+                    assertTrue(event instanceof Message,
+                            "Expected Message after initial Task, but was: " + event.getClass().getSimpleName());
+                }
+
                 subscription.request(1);
             }
@@ -892,16 +876,15 @@ public void onComplete() {

         future.join();

-        // The Python implementation has several events emitted since it uses mocks.
-        //
-        // See testOnMessageStreamNewMessageExistingTaskSuccessMocks() for a test more similar to the Python implementation
-        assertEquals(1, results.size());
+        // Verify we received exactly 2 events and the initial Task was received
+        assertEquals(2, results.size());
+        assertTrue(receivedInitialTask.get(), "Should have received initial Task event");
     }

     @Test
     public void testOnResubscribeExistingTaskSuccessMocks() throws Exception {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         queueManager.createOrTap(MINIMAL_TASK.id());

         List events = List.of(
@@ -1060,7 +1043,7 @@ public void onComplete() {
         if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) {
             assertEquals("Streaming is not supported by the agent", ire.getMessage());
         } else {
-            Assertions.fail("Expected a response containing an error");
+            fail("Expected a response containing an error");
         }
     }
@@ -1107,7 +1090,7 @@ public void onComplete() {
         if (results.get(0).getError() != null && results.get(0).getError() instanceof InvalidRequestError ire) {
             assertEquals("Streaming is not supported by the agent", ire.getMessage());
         } else {
-            Assertions.fail("Expected a response containing an error");
+            fail("Expected a response containing an error");
         }
     }
@@ -1115,7 +1098,7 @@
     public void testPushNotificationsNotSupportedError() {
         AgentCard card = createAgentCard(true, false, true);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);

         TaskPushNotificationConfig config =
                 new TaskPushNotificationConfig(
@@ -1135,12 +1118,11 @@ public void testPushNotificationsNotSupportedError() {

     @Test
     public void testOnGetPushNotificationNoPushNotifierConfig() {
         // Create request handler without a push notifier
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         AgentCard card = createAgentCard(false, true, false);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);

         GetTaskPushNotificationConfigRequest request = new GetTaskPushNotificationConfigRequest("id",
                 new GetTaskPushNotificationConfigParams(MINIMAL_TASK.id()));
@@ -1154,12 +1136,11 @@ public void testOnGetPushNotificationNoPushNotifierConfig() {

     @Test
     public void testOnSetPushNotificationNoPushNotifierConfig() {
         // Create request handler without a push notifier
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         AgentCard card = createAgentCard(false, true, false);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);

         TaskPushNotificationConfig config =
                 new TaskPushNotificationConfig(
@@ -1246,12 +1227,11 @@ public void testDefaultRequestHandlerWithCustomComponents() {

     @Test
     public void testOnMessageSendErrorHandling() {
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         AgentCard card = createAgentCard(false, true, false);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);

         Message message = Message.builder(MESSAGE)
                 .taskId(MINIMAL_TASK.id())
@@ -1280,7 +1260,7 @@ public void testOnMessageSendErrorHandling() {

     @Test
     public void testOnMessageSendTaskIdMismatch() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);

         agentExecutorExecute = ((context, eventQueue) -> {
             eventQueue.enqueueEvent(MINIMAL_TASK);
@@ -1293,16 +1273,17 @@ public void testOnMessageSendTaskIdMismatch() {
     }

     @Test
-    public void testOnMessageStreamTaskIdMismatch() {
+    public void testOnMessageStreamTaskIdMismatch() throws InterruptedException {
+        // Note: Do NOT set callback - DefaultRequestHandler has a permanent callback
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);

-        agentExecutorExecute = ((context, eventQueue) -> {
-            eventQueue.enqueueEvent(MINIMAL_TASK);
-        });
+        agentExecutorExecute = ((context, eventQueue) -> {
+            eventQueue.enqueueEvent(MINIMAL_TASK);
+        });

-        SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null));
-        Flow.Publisher response = handler.onMessageSendStream(request, callContext);
+        SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null));
+        Flow.Publisher response = handler.onMessageSendStream(request, callContext);

         CompletableFuture future = new CompletableFuture<>();
         List results = new ArrayList<>();
@@ -1347,7 +1328,7 @@ public void onComplete() {

     @Test
     public void testListPushNotificationConfig() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1378,7 +1359,7 @@ public void testListPushNotificationConfig() {
     public void testListPushNotificationConfigNotSupported() {
         AgentCard card = createAgentCard(true, false, true);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1404,10 +1385,9 @@ public void testListPushNotificationConfigNotSupported() {

     @Test
     public void testListPushNotificationConfigNoPushConfigStore() {
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1442,7 +1422,7 @@ public void testListPushNotificationConfigTaskNotFound() {

     @Test
     public void testDeletePushNotificationConfig() {
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1470,7 +1450,7 @@ public void testDeletePushNotificationConfig() {
     public void testDeletePushNotificationConfigNotSupported() {
         AgentCard card = createAgentCard(true, false, true);
         JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1496,10 +1476,10 @@ public void testDeletePushNotificationConfigNotSupported() {

     @Test
     public void testDeletePushNotificationConfigNoPushConfigStore() {
-        DefaultRequestHandler requestHandler = DefaultRequestHandler.create(
-                executor, taskStore, queueManager, null, null, internalExecutor);
+        DefaultRequestHandler requestHandler =
+                DefaultRequestHandler.create(executor, taskStore, queueManager, null, mainEventBusProcessor, internalExecutor, internalExecutor);
         JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler, internalExecutor);
-        taskStore.save(MINIMAL_TASK);
+        taskStore.save(MINIMAL_TASK, false);
         agentExecutorExecute = (context, eventQueue) -> {
             eventQueue.enqueueEvent(context.getTask() != null ? context.getTask() : context.getMessage());
         };
@@ -1593,7 +1573,7 @@ public void onComplete() {
         });

         // The main thread should not be blocked - we should be able to continue immediately
-        Assertions.assertTrue(streamStarted.await(100, TimeUnit.MILLISECONDS),
+        assertTrue(streamStarted.await(100, TimeUnit.MILLISECONDS),
                 "Streaming subscription should start quickly without blocking main thread");

         // This proves the main thread is not blocked - we can do other work
@@ -1603,11 +1583,11 @@ public void onComplete() {
         mainThreadBlocked.set(false); // If we get here, main thread was not blocked

         // Wait for the actual event processing to complete
-        Assertions.assertTrue(eventProcessed.await(2, TimeUnit.SECONDS),
+        assertTrue(eventProcessed.await(2, TimeUnit.SECONDS),
                 "Event should be processed within reasonable time");

         // Verify we received the event and main thread was not blocked
-        Assertions.assertTrue(eventReceived.get(), "Should have received streaming event");
+        assertTrue(eventReceived.get(), "Should have received streaming event");
         Assertions.assertFalse(mainThreadBlocked.get(), "Main thread should not have been blocked");
     }
@@ -1646,7 +1626,7 @@ public void testExtensionSupportRequiredErrorOnMessageSend() {
         SendMessageResponse response = handler.onMessageSend(request, callContext);

         assertInstanceOf(ExtensionSupportRequiredError.class, response.getError());
-        Assertions.assertTrue(response.getError().getMessage().contains("https://example.com/test-extension"));
+        assertTrue(response.getError().getMessage().contains("https://example.com/test-extension"));
         assertNull(response.getResult());
     }
@@ -1715,7 +1695,7 @@ public void onComplete() {
         assertEquals(1, results.size());
         assertInstanceOf(ExtensionSupportRequiredError.class, results.get(0).getError());
-        Assertions.assertTrue(results.get(0).getError().getMessage().contains("https://example.com/streaming-extension"));
+        assertTrue(results.get(0).getError().getMessage().contains("https://example.com/streaming-extension"));
         assertNull(results.get(0).getResult());
     }
@@ -1805,7 +1785,7 @@ public void testVersionNotSupportedErrorOnMessageSend() {
         SendMessageResponse response = handler.onMessageSend(request, contextWithVersion);

         assertInstanceOf(VersionNotSupportedError.class, response.getError());
-        Assertions.assertTrue(response.getError().getMessage().contains("2.0"));
+        assertTrue(response.getError().getMessage().contains("2.0"));
         assertNull(response.getResult());
     }
@@ -1876,12 +1856,12 @@ public void onComplete() {
         });

         // Wait for async processing
-        Assertions.assertTrue(latch.await(2, TimeUnit.SECONDS), "Expected to receive error event within timeout");
+        assertTrue(latch.await(2, TimeUnit.SECONDS), "Expected to receive error event within timeout");

         assertEquals(1, results.size());
         SendStreamingMessageResponse result = results.get(0);
         assertInstanceOf(VersionNotSupportedError.class, result.getError());
-        Assertions.assertTrue(result.getError().getMessage().contains("2.0"));
+        assertTrue(result.getError().getMessage().contains("2.0"));
         assertNull(result.getResult());
     }
diff --git a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java
index 3ffb56c5f..3273d6119 100644
--- a/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java
+++ b/transport/rest/src/main/java/io/a2a/transport/rest/handler/RestHandler.java
@@ -399,32 +399,46 @@ private Flow.Publisher convertToSendStreamingMessageResponse(
             Flow.Publisher publisher) {
         // We can't use the normal convertingProcessor since that propagates any errors as an error handled
         // via Subscriber.onError() rather than as part of the SendStreamingResponse payload
+        log.log(Level.FINE, "REST: convertToSendStreamingMessageResponse called, creating ZeroPublisher");
         return ZeroPublisher.create(createTubeConfig(), tube -> {
+            log.log(Level.FINE, "REST: ZeroPublisher tube created, starting CompletableFuture.runAsync");
             CompletableFuture.runAsync(() -> {
+                log.log(Level.FINE, "REST: Inside CompletableFuture, subscribing to EventKind publisher");
                 publisher.subscribe(new Flow.Subscriber() {
                     Flow.@Nullable Subscription subscription;

                     @Override
                     public void onSubscribe(Flow.Subscription subscription) {
+                        log.log(Level.FINE, "REST: onSubscribe called, storing subscription and requesting first event");
                         this.subscription = subscription;
                         subscription.request(1);
                     }

                     @Override
                     public void onNext(StreamingEventKind item) {
+                        log.log(Level.FINE, "REST: onNext called with event: {0}", item.getClass().getSimpleName());
                         try {
                             String payload = JsonFormat.printer().omittingInsignificantWhitespace().print(ProtoUtils.ToProto.taskOrMessageStream(item));
+                            log.log(Level.FINE, "REST: Converted to JSON, sending via tube: {0}", payload.substring(0, Math.min(100, payload.length())));
                             tube.send(payload);
+                            log.log(Level.FINE, "REST: tube.send() completed, requesting next event from EventConsumer");
+                            // Request next event from EventConsumer (Chain 1: EventConsumer → RestHandler)
+                            // This is safe because ZeroPublisher buffers items
+                            // Chain 2 (ZeroPublisher → MultiSseSupport) controls actual delivery via request(1) in onWriteDone()
                             if (subscription != null) {
                                 subscription.request(1);
+                            } else {
+                                log.log(Level.WARNING, "REST: subscription is null in onNext!");
                             }
                         } catch (InvalidProtocolBufferException ex) {
+                            log.log(Level.SEVERE, "REST: JSON conversion failed", ex);
                             onError(ex);
                         }
                     }

                     @Override
                     public void onError(Throwable throwable) {
+                        log.log(Level.SEVERE, "REST: onError called", throwable);
                         if (throwable instanceof A2AError jsonrpcError) {
                             tube.send(new HTTPRestErrorResponse(jsonrpcError).toJson());
                         } else {
@@ -435,6 +449,7 @@ public void onError(Throwable throwable) {

                     @Override
                     public void onComplete() {
+                        log.log(Level.FINE, "REST: onComplete called, calling tube.complete()");
                         tube.complete();
                     }
                 });
diff
--git a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java index 7d930415b..db3ff97aa 100644 --- a/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java +++ b/transport/rest/src/test/java/io/a2a/transport/rest/handler/RestHandlerTest.java @@ -30,7 +30,7 @@ public class RestHandlerTest extends AbstractA2ARequestHandlerTest { @Test public void testGetTaskSuccess() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); RestHandler.HTTPRestResponse response = handler.getTask(MINIMAL_TASK.id(), 0, "", callContext); @@ -59,7 +59,7 @@ public void testGetTaskNotFound() { @Test public void testListTasksStatusWireString() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); RestHandler.HTTPRestResponse response = handler.listTasks(null, "submitted", null, null, null, null, null, "", callContext); @@ -162,7 +162,7 @@ public void testSendMessageEmptyBody() { @Test public void testCancelTaskSuccess() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); agentExecutorCancel = (context, eventQueue) -> { // We need to cancel the task or the EventConsumer never finds a 'final' event. 
@@ -246,7 +246,7 @@ public void testSendStreamingMessageNotSupported() { @Test public void testPushNotificationConfigSuccess() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); String requestBody = """ { @@ -293,7 +293,7 @@ public void testPushNotificationConfigNotSupported() { @Test public void testGetPushNotificationConfig() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); // First, create a push notification config String createRequestBody = """ @@ -322,7 +322,7 @@ public void testGetPushNotificationConfig() { @Test public void testDeletePushNotificationConfig() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); RestHandler.HTTPRestResponse response = handler.deleteTaskPushNotificationConfiguration(MINIMAL_TASK.id(), "default-config-id", "", callContext); Assertions.assertEquals(204, response.getStatusCode()); } @@ -330,7 +330,7 @@ public void testDeletePushNotificationConfig() { @Test public void testListPushNotificationConfigs() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); RestHandler.HTTPRestResponse response = handler.listTaskPushNotificationConfigurations(MINIMAL_TASK.id(), 0, "", "", callContext); @@ -894,7 +894,7 @@ public void testListTasksNegativeTimestampReturns422() { @Test public void testListTasksUnixMillisecondsTimestamp() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); // Unix milliseconds timestamp should be accepted String timestampMillis = String.valueOf(System.currentTimeMillis() - 10000); @@ -909,7 +909,7 @@ public void 
testListTasksUnixMillisecondsTimestamp() { @Test public void testListTasksProtobufEnumStatus() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); // Protobuf enum format (TASK_STATE_SUBMITTED) should be accepted RestHandler.HTTPRestResponse response = handler.listTasks(null, "TASK_STATE_SUBMITTED", null, null, @@ -923,7 +923,7 @@ public void testListTasksProtobufEnumStatus() { @Test public void testListTasksEnumConstantStatus() { RestHandler handler = new RestHandler(CARD, requestHandler, internalExecutor); - taskStore.save(MINIMAL_TASK); + taskStore.save(MINIMAL_TASK, false); // Enum constant format (SUBMITTED) should be accepted RestHandler.HTTPRestResponse response = handler.listTasks(null, "SUBMITTED", null, null, From 929217e449a1e1d4740a15f123e513aa5729fe63 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Tue, 3 Feb 2026 11:16:12 +0000 Subject: [PATCH 2/8] Fix flaky test --- .../replicated/core/ReplicatedQueueManagerTest.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java index 7f53bfbfd..b18198aaa 100644 --- a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java +++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java @@ -308,12 +308,15 @@ void testParallelReplicationBehavior() throws InterruptedException { int numThreads = 10; int eventsPerThread = 5; int expectedEventCount = (numThreads / 2) * eventsPerThread; // Only normal enqueues + int totalEventCount = numThreads * eventsPerThread; // All events (normal + replicated) ExecutorService 
executor = Executors.newFixedThreadPool(numThreads); CountDownLatch startLatch = new CountDownLatch(1); CountDownLatch doneLatch = new CountDownLatch(numThreads); - // Set up callback to wait for all events to be processed by MainEventBusProcessor - CountDownLatch processingLatch = new CountDownLatch(expectedEventCount); + // Set up callback to wait for ALL events to be processed by MainEventBusProcessor + // Must wait for all 50 events (25 normal + 25 replicated) to ensure all normal events + // have triggered replication before we check the count + CountDownLatch processingLatch = new CountDownLatch(totalEventCount); mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { @Override public void onEventProcessed(String tid, io.a2a.spec.Event event) { From d53e4650ec4ecd36bd51480dd2212293681e123b Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Tue, 3 Feb 2026 11:29:16 +0000 Subject: [PATCH 3/8] Improve comments after Gemini review --- .../DefaultRequestHandler.java | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index 77ba83b8b..baf5ad4a7 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -515,12 +515,16 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte LOGGER.debug("DefaultRequestHandler: Step 3 - Consumption completed for task {}", taskId.get()); } - // NOTE: We do NOT wait for MainEventBusProcessor to finalize the task here. - // This would require blocking, which breaks gRPC (Vert.x event loop). - // In practice, events are processed within milliseconds, so the race - // condition where TaskStore is not fully updated is minimal. 
- // For platform-agnostic SDK design, we accept this minor race condition - // rather than blocking event loop threads. + // Step 4: Implicit guarantee of persistence via consumption completion + // We do NOT add an explicit wait for MainEventBusProcessor here because: + // 1. MainEventBusProcessor persists BEFORE distributing to ChildQueues + // 2. Step 3 (consumption completion) already guarantees all consumed events are persisted + // 3. Adding another explicit synchronization point would require exposing + // MainEventBusProcessor internals and blocking event loop threads + // + // Note: For fire-and-forget tasks, if the agent is still running after Step 1 timeout, + // it may enqueue additional events. These will be persisted asynchronously but won't + // be included in the task state returned to the client (already consumed in Step 3). } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -538,7 +542,10 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte throw new InternalError(msg); } - // Step 5: Fetch the final task state from TaskStore (now guaranteed persisted) + // Step 5: Fetch the current task state from TaskStore + // All events consumed in Step 3 are guaranteed persisted (MainEventBusProcessor + // ordering: persist → distribute → consume). This returns the persisted state + // including all consumed events and artifacts. 
String nonNullTaskId = Objects.requireNonNull(taskId.get(), "taskId cannot be null"); Task updatedTask = taskStore.get(nonNullTaskId); if (updatedTask != null) { From de93002a4ef82b75b64c282fc90da16628f8c8ae Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Tue, 3 Feb 2026 11:29:16 +0000 Subject: [PATCH 4/8] Improve comments after Gemini review --- .../server/requesthandlers/DefaultRequestHandler.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index baf5ad4a7..c476c8741 100644 --- a/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -481,13 +481,13 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte } if (blocking && interruptedOrNonBlocking) { - // For blocking calls: ensure all events are persisted to TaskStore before returning + // For blocking calls: ensure all consumed events are persisted to TaskStore before returning // Order of operations is critical to avoid circular dependency and race conditions: - // 1. Wait for agent to finish enqueueing events + // 1. Wait for agent to finish enqueueing events (or timeout) // 2. Close the queue to signal consumption can complete // 3. Wait for consumption to finish processing events - // 4. Wait for MainEventBusProcessor to persist final state to TaskStore - // 5. Fetch final task state from TaskStore (now guaranteed persisted) + // 4. (Implicit) MainEventBusProcessor persistence guarantee via consumption completion + // 5. 
Fetch current task state from TaskStore (includes all consumed & persisted events) LOGGER.debug("DefaultRequestHandler: Entering blocking fire-and-forget handling for task {}", taskId.get()); try { @@ -550,7 +550,7 @@ public EventKind onMessageSend(MessageSendParams params, ServerCallContext conte Task updatedTask = taskStore.get(nonNullTaskId); if (updatedTask != null) { kind = updatedTask; - LOGGER.debug("DefaultRequestHandler: Step 5 - Fetched final task for {} with state {} and {} artifacts", + LOGGER.debug("DefaultRequestHandler: Step 5 - Fetched current task for {} with state {} and {} artifacts", taskId.get(), updatedTask.status().state(), updatedTask.artifacts().size()); } else { From 2e5b3cc502666ead686c0f945bcb94e10e665a88 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Tue, 3 Feb 2026 15:17:26 +0000 Subject: [PATCH 5/8] Close the temporary child queue in ReplicatedQueueManager --- .../core/ReplicatedQueueManager.java | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java index d232dea98..15a621712 100644 --- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java +++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java @@ -113,21 +113,26 @@ public void onReplicatedEvent(@Observes ReplicatedEventQueueItem replicatedEvent // // If MainQueue doesn't exist, create it. This handles late-arriving replicated events // for tasks that were created on another instance. 
+ EventQueue childQueue = null; // Track ChildQueue we might create EventQueue mainQueue = delegate.get(replicatedEvent.getTaskId()); - if (mainQueue == null) { - // MainQueue doesn't exist - create it by calling createOrTap and then get the MainQueue - // Replicated events should always have real task IDs (not temp IDs) because - // replication now happens AFTER TaskStore persistence in MainEventBusProcessor - LOGGER.debug("Creating MainQueue for replicated event on task {}", replicatedEvent.getTaskId()); - delegate.createOrTap(replicatedEvent.getTaskId()); - mainQueue = delegate.get(replicatedEvent.getTaskId()); - } + try { + if (mainQueue == null) { + LOGGER.debug("Creating MainQueue for replicated event on task {}", replicatedEvent.getTaskId()); + childQueue = delegate.createOrTap(replicatedEvent.getTaskId()); // Creates MainQueue + returns ChildQueue + mainQueue = delegate.get(replicatedEvent.getTaskId()); // Get MainQueue from map + } - if (mainQueue != null) { - mainQueue.enqueueItem(replicatedEvent); - } else { - LOGGER.warn("MainQueue not found for task {}, cannot enqueue replicated event. This may happen if the queue was already cleaned up.", - replicatedEvent.getTaskId()); + if (mainQueue != null) { + mainQueue.enqueueItem(replicatedEvent); + } else { + LOGGER.warn( + "MainQueue not found for task {}, cannot enqueue replicated event. This may happen if the queue was already cleaned up.", + replicatedEvent.getTaskId()); + } + } finally { + if (childQueue != null) { + childQueue.close(); // Close the ChildQueue we created (not MainQueue!) 
+ } } } From 4f092e2249fc071ad4ec1f5db8e69d50ee204af9 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Wed, 4 Feb 2026 14:47:13 +0000 Subject: [PATCH 6/8] Only send push notifications for non-replicated events --- .../java/io/a2a/server/events/MainEventBusProcessor.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java index fbfc4aa0b..8b3dc6fa3 100644 --- a/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java +++ b/server-common/src/main/java/io/a2a/server/events/MainEventBusProcessor.java @@ -176,12 +176,12 @@ private void processEvent(MainEventBusContext context) { taskId, event.getClass().getSimpleName()); Event eventToDistribute = null; + boolean isReplicated = context.eventQueueItem().isReplicated(); try { // Step 1: Update TaskStore FIRST (persistence before clients see it) // If this throws, we distribute an error to ensure "persist before client visibility" try { - boolean isReplicated = context.eventQueueItem().isReplicated(); boolean isFinal = updateTaskStore(taskId, event, isReplicated); eventToDistribute = event; // Success - distribute original event @@ -209,8 +209,9 @@ private void processEvent(MainEventBusContext context) { eventToDistribute = new InternalError(errorMessage); } - // Step 2: Send push notification AFTER successful persistence - if (eventToDistribute == event) { + // Step 2: Send push notification AFTER successful persistence (only from active node) + // Skip push notifications for replicated events to avoid duplicate notifications in multi-instance deployments + if (eventToDistribute == event && !isReplicated) { // Capture task state immediately after persistence, before going async // This ensures we send the task as it existed when THIS event was processed, // not whatever state might exist later when the async callback executes From 
49f6efee5df3ff82144444cf631a7b742e03d103 Mon Sep 17 00:00:00 2001 From: Emmanuel Hugonnet Date: Wed, 4 Feb 2026 18:02:48 +0100 Subject: [PATCH 7/8] Updating test using a CyclicBarrier for better synchronization Signed-off-by: Emmanuel Hugonnet --- .../core/ReplicatedQueueManagerTest.java | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java index b18198aaa..14b4c1f51 100644 --- a/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java +++ b/extras/queue-manager-replicated/core/src/test/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManagerTest.java @@ -312,6 +312,14 @@ void testParallelReplicationBehavior() throws InterruptedException { ExecutorService executor = Executors.newFixedThreadPool(numThreads); CountDownLatch startLatch = new CountDownLatch(1); CountDownLatch doneLatch = new CountDownLatch(numThreads); + + // Use CyclicBarrier for better thread synchronization + // This ensures all threads start their work at approximately the same time + java.util.concurrent.CyclicBarrier barrier = new java.util.concurrent.CyclicBarrier(numThreads); + + // Track processed events for better diagnostics on failure + java.util.concurrent.CopyOnWriteArrayList processedEvents = + new java.util.concurrent.CopyOnWriteArrayList<>(); // Set up callback to wait for ALL events to be processed by MainEventBusProcessor // Must wait for all 50 events (25 normal + 25 replicated) to ensure all normal events @@ -320,6 +328,7 @@ void testParallelReplicationBehavior() throws InterruptedException { mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() { @Override public void 
onEventProcessed(String tid, io.a2a.spec.Event event) { + processedEvents.add(event); processingLatch.countDown(); } @@ -335,6 +344,7 @@ public void onTaskFinalized(String tid) { executor.submit(() -> { try { startLatch.await(); + barrier.await(); // Synchronize thread starts for better interleaving for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId(taskId) // Use same taskId as queue @@ -343,10 +353,11 @@ public void onTaskFinalized(String tid) { .isFinal(false) .build(); queue.enqueueEvent(event); - Thread.sleep(1); // Small delay to interleave operations } } catch (InterruptedException e) { Thread.currentThread().interrupt(); + } catch (java.util.concurrent.BrokenBarrierException e) { + throw new RuntimeException("Barrier broken", e); } finally { doneLatch.countDown(); } @@ -359,6 +370,7 @@ public void onTaskFinalized(String tid) { executor.submit(() -> { try { startLatch.await(); + barrier.await(); // Synchronize thread starts for better interleaving for (int j = 0; j < eventsPerThread; j++) { TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder() .taskId(taskId) // Use same taskId as queue @@ -368,10 +380,11 @@ public void onTaskFinalized(String tid) { .build(); ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, event); queueManager.onReplicatedEvent(replicatedEvent); - Thread.sleep(1); // Small delay to interleave operations } } catch (InterruptedException e) { Thread.currentThread().interrupt(); + } catch (java.util.concurrent.BrokenBarrierException e) { + throw new RuntimeException("Barrier broken", e); } finally { doneLatch.countDown(); } @@ -381,25 +394,36 @@ public void onTaskFinalized(String tid) { // Start all threads simultaneously startLatch.countDown(); - // Wait for all threads to complete - assertTrue(doneLatch.await(10, TimeUnit.SECONDS), "All threads should complete within 10 seconds"); + // Wait for all threads to complete with explicit timeout + 
assertTrue(doneLatch.await(10, TimeUnit.SECONDS), + "All " + numThreads + " threads should complete within 10 seconds"); executor.shutdown(); - assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS), "Executor should shutdown within 5 seconds"); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS), + "Executor should shutdown within 5 seconds"); // Wait for MainEventBusProcessor to process all events try { - assertTrue(processingLatch.await(10, TimeUnit.SECONDS), - "MainEventBusProcessor should have processed all events within timeout"); + boolean allProcessed = processingLatch.await(10, TimeUnit.SECONDS); + assertTrue(allProcessed, + String.format("MainEventBusProcessor should have processed all %d events within timeout. " + + "Processed: %d, Remaining: %d", + totalEventCount, processedEvents.size(), processingLatch.getCount())); } finally { mainEventBusProcessor.setCallback(null); + queue.close(true, true); } + // Verify we processed the expected number of events + assertEquals(totalEventCount, processedEvents.size(), + "Should have processed exactly " + totalEventCount + " events (normal + replicated)"); + // Only the normal enqueue operations should have triggered replication // numThreads/2 threads * eventsPerThread events each = total expected replication calls int expectedReplicationCalls = (numThreads / 2) * eventsPerThread; assertEquals(expectedReplicationCalls, strategy.getCallCount(), - "Only normal enqueue operations should trigger replication, not replicated events"); + String.format("Only normal enqueue operations should trigger replication, not replicated events. 
+ Expected: %d, Actual: %d", expectedReplicationCalls, strategy.getCallCount())); } @Test From a7443352e0a82d330e9ac94b558e2a58565407fb Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Wed, 4 Feb 2026 17:10:37 +0000 Subject: [PATCH 8/8] Catch close exception, and debug log stacktrace --- .../replicated/core/ReplicatedQueueManager.java | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java index 15a621712..44dfbe427 100644 --- a/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java +++ b/extras/queue-manager-replicated/core/src/main/java/io/a2a/extras/queuemanager/replicated/core/ReplicatedQueueManager.java @@ -131,7 +131,14 @@ public void onReplicatedEvent(@Observes ReplicatedEventQueueItem replicatedEvent } } finally { if (childQueue != null) { - childQueue.close(); // Close the ChildQueue we created (not MainQueue!) + try { + childQueue.close(); // Close the ChildQueue we created (not MainQueue!) + } catch (Exception ignore) { + // Close failures are harmless here; log the stacktrace at debug level for diagnostics + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Failed to close temporary ChildQueue for task {}", replicatedEvent.getTaskId(), ignore); + } + } } } }
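The commit message for this series describes the MainEventBusProcessor ordering guarantee: a single background thread persists each event to the TaskStore *before* distributing it to ChildQueues, so clients can never observe un-persisted state. As a rough illustration of why a single consumer thread over a shared `BlockingQueue` provides that guarantee, here is a minimal, self-contained sketch — all names (`MainEventBusSketch`, `Event`, the map standing in for `TaskStore`) are hypothetical, not the SDK's actual API:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;

// Sketch: one shared queue, one consumer thread.
// Persist (step 1) always happens-before distribute (step 2) because
// both run sequentially on the same thread, per event, in FIFO order.
public class MainEventBusSketch {
    record Event(String taskId, String payload) {}

    public static void main(String[] args) throws InterruptedException {
        BlockingQueue<Event> bus = new LinkedBlockingQueue<>();
        Map<String, String> taskStore = new ConcurrentHashMap<>(); // stand-in for TaskStore
        List<String> delivered = new ArrayList<>();                // stand-in for ChildQueues

        Thread processor = new Thread(() -> {
            try {
                while (true) {
                    Event e = bus.take();
                    if (e.payload().equals("POISON")) {
                        break; // poison pill: ordered shutdown after all prior events
                    }
                    taskStore.put(e.taskId(), e.payload());        // Step 1: persist FIRST
                    // Step 2: distribute only after persistence, so no client
                    // ever sees an event the store has not recorded.
                    boolean persistedFirst = taskStore.containsKey(e.taskId());
                    delivered.add(e.payload() + " persistedFirst=" + persistedFirst);
                }
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
            }
        });
        processor.start();

        bus.put(new Event("task-1", "submitted"));
        bus.put(new Event("task-1", "completed"));
        bus.put(new Event("task-1", "POISON"));
        processor.join(); // join gives a happens-before edge for reading `delivered`

        delivered.forEach(System.out::println);
    }
}
```

The poison-pill shutdown mirrors the "poison pill ordering" exercised by MultiInstanceReplicationTest: because the pill goes through the same FIFO queue as the data events, every earlier event is persisted and distributed before the processor stops. Serial consumption is also what removes the concurrent-TaskStore-update race the first commit describes — no two events for the same task are ever persisted at the same time.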