Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,14 @@ void testParallelReplicationBehavior() throws InterruptedException {
ExecutorService executor = Executors.newFixedThreadPool(numThreads);
CountDownLatch startLatch = new CountDownLatch(1);
CountDownLatch doneLatch = new CountDownLatch(numThreads);

// Use CyclicBarrier for better thread synchronization
// This ensures all threads start their work at approximately the same time
java.util.concurrent.CyclicBarrier barrier = new java.util.concurrent.CyclicBarrier(numThreads);

// Track processed events for better diagnostics on failure
java.util.concurrent.CopyOnWriteArrayList<io.a2a.spec.Event> processedEvents =
new java.util.concurrent.CopyOnWriteArrayList<>();

// Set up callback to wait for ALL events to be processed by MainEventBusProcessor
// Must wait for all 50 events (25 normal + 25 replicated) to ensure all normal events
Expand All @@ -320,6 +328,7 @@ void testParallelReplicationBehavior() throws InterruptedException {
mainEventBusProcessor.setCallback(new io.a2a.server.events.MainEventBusProcessorCallback() {
@Override
public void onEventProcessed(String tid, io.a2a.spec.Event event) {
processedEvents.add(event);
processingLatch.countDown();
}

Expand All @@ -335,6 +344,7 @@ public void onTaskFinalized(String tid) {
executor.submit(() -> {
try {
startLatch.await();
barrier.await(); // Synchronize thread starts for better interleaving
for (int j = 0; j < eventsPerThread; j++) {
TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder()
.taskId(taskId) // Use same taskId as queue
Expand All @@ -343,10 +353,11 @@ public void onTaskFinalized(String tid) {
.isFinal(false)
.build();
queue.enqueueEvent(event);
Thread.sleep(1); // Small delay to interleave operations
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (java.util.concurrent.BrokenBarrierException e) {
throw new RuntimeException("Barrier broken", e);
} finally {
doneLatch.countDown();
}
Expand All @@ -359,6 +370,7 @@ public void onTaskFinalized(String tid) {
executor.submit(() -> {
try {
startLatch.await();
barrier.await(); // Synchronize thread starts for better interleaving
for (int j = 0; j < eventsPerThread; j++) {
TaskStatusUpdateEvent event = TaskStatusUpdateEvent.builder()
.taskId(taskId) // Use same taskId as queue
Expand All @@ -368,10 +380,11 @@ public void onTaskFinalized(String tid) {
.build();
ReplicatedEventQueueItem replicatedEvent = new ReplicatedEventQueueItem(taskId, event);
queueManager.onReplicatedEvent(replicatedEvent);
Thread.sleep(1); // Small delay to interleave operations
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (java.util.concurrent.BrokenBarrierException e) {
throw new RuntimeException("Barrier broken", e);
} finally {
doneLatch.countDown();
}
Expand All @@ -381,25 +394,36 @@ public void onTaskFinalized(String tid) {
// Start all threads simultaneously
startLatch.countDown();

// Wait for all threads to complete
assertTrue(doneLatch.await(10, TimeUnit.SECONDS), "All threads should complete within 10 seconds");
// Wait for all threads to complete with explicit timeout
assertTrue(doneLatch.await(10, TimeUnit.SECONDS),
"All " + numThreads + " threads should complete within 10 seconds");

executor.shutdown();
assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS), "Executor should shutdown within 5 seconds");
assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS),
"Executor should shutdown within 5 seconds");

// Wait for MainEventBusProcessor to process all events
try {
assertTrue(processingLatch.await(10, TimeUnit.SECONDS),
"MainEventBusProcessor should have processed all events within timeout");
boolean allProcessed = processingLatch.await(10, TimeUnit.SECONDS);
assertTrue(allProcessed,
String.format("MainEventBusProcessor should have processed all %d events within timeout. " +
"Processed: %d, Remaining: %d",
totalEventCount, processedEvents.size(), processingLatch.getCount()));
} finally {
mainEventBusProcessor.setCallback(null);
queue.close(true, true);
}

// Verify we processed the expected number of events
assertEquals(totalEventCount, processedEvents.size(),
"Should have processed exactly " + totalEventCount + " events (normal + replicated)");

// Only the normal enqueue operations should have triggered replication
// numThreads/2 threads * eventsPerThread events each = total expected replication calls
int expectedReplicationCalls = (numThreads / 2) * eventsPerThread;
assertEquals(expectedReplicationCalls, strategy.getCallCount(),
"Only normal enqueue operations should trigger replication, not replicated events");
String.format("Only normal enqueue operations should trigger replication, not replicated events. " +
"Expected: %d, Actual: %d", expectedReplicationCalls, strategy.getCallCount()));
}

@Test
Expand Down
Loading