diff --git a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/acknowledgement/BatchingAcknowledgementProcessor.java b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/acknowledgement/BatchingAcknowledgementProcessor.java index 177a4c391..562fcb3fb 100644 --- a/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/acknowledgement/BatchingAcknowledgementProcessor.java +++ b/spring-cloud-aws-sqs/src/main/java/io/awspring/cloud/sqs/listener/acknowledgement/BatchingAcknowledgementProcessor.java @@ -21,6 +21,7 @@ import io.awspring.cloud.sqs.listener.TaskExecutorAware; import java.time.Duration; import java.time.Instant; +import java.time.LocalDateTime; import java.util.Collection; import java.util.Collections; import java.util.HashSet; @@ -35,6 +36,7 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Function; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -105,7 +107,8 @@ protected CompletableFuture doOnAcknowledge(Message message) { logger.warn("Acknowledgement queue full, dropping acknowledgement for message {}", MessageHeaderUtils.getId(message)); } - logger.trace("Received message {} to ack in {}.", MessageHeaderUtils.getId(message), getId()); + logger.trace("Received message {} to ack in {}. Primary queue size: {}", MessageHeaderUtils.getId(message), + getId(), acks.size()); return CompletableFuture.completedFuture(null); } @@ -143,7 +146,11 @@ protected BufferingAcknowledgementProcessor createAcknowledgementProcessor() @Override public void doStop() { - this.acknowledgementProcessor.waitAcknowledgementsToFinish(); + try { + this.acknowledgementProcessor.waitAcknowledgementsToFinish(); + } catch (Exception e) { + logger.error("Error waiting for acknowledgements to finish. Proceeding with shutdown.", e); + } LifecycleHandler.get().dispose(this.taskScheduler); } @@ -187,7 +194,7 @@ private BufferingAcknowledgementProcessor(BatchingAcknowledgementProcessor pa public void run() { logger.debug("Starting acknowledgement processor thread with batchSize: {}", this.ackThreshold); this.scheduledExecution.start(); - while (this.parent.isRunning()) { + while (shouldKeepPollingAcks()) { try { Message polledMessage = this.acks.poll(1, TimeUnit.SECONDS); if (polledMessage != null) { @@ -202,6 +209,10 @@ public void run() { logger.debug("Acknowledgement processor thread stopped"); } + private boolean shouldKeepPollingAcks() { + return this.parent.isRunning() || !this.context.isTimeoutElapsed; + } + private void addMessageToBuffer(Message polledMessage) { this.context.lock(); try { @@ -214,9 +225,33 @@ private void addMessageToBuffer(Message polledMessage) { } public void waitAcknowledgementsToFinish() { + waitOnAcknowledgementsIfTimeoutSet(); + this.context.isTimeoutElapsed = true; + this.context.lock(); try { - CompletableFuture.allOf(this.context.runningAcks.toArray(new CompletableFuture[] {})) - .get(this.ackShutdownTimeout.toMillis(), TimeUnit.MILLISECONDS); + this.context.acksBuffer.clear(); + } + finally { + this.context.unlock(); + } + this.context.runningAcks.forEach(future -> future.cancel(true)); + } + + private void waitOnAcknowledgementsIfTimeoutSet() { + if (Duration.ZERO.equals(this.ackShutdownTimeout)) { + logger.debug("Not waiting for acknowledgements, shutting down."); + return; + } + try { + var endTime = LocalDateTime.now().plus(this.ackShutdownTimeout); + logger.debug("Waiting until {} for acknowledgements to finish", endTime); + while (hasAcksLeft() || hasUnfinishedAcks()) { + if (LocalDateTime.now().isAfter(endTime)) { + throw new TimeoutException(); + } + Thread.sleep(200); + } + logger.debug("All acknowledgements completed."); } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -231,9 +266,21 @@ public void waitAcknowledgementsToFinish() { "Error thrown when waiting for acknowledgement tasks to finish in {}. Continuing with shutdown.", this.parent.getId(), e); } - if (!this.context.runningAcks.isEmpty()) { - this.context.runningAcks.forEach(future -> future.cancel(true)); - } + } + + private boolean hasUnfinishedAcks() { + var unfinishedAcks = this.context.runningAcks.stream().filter(Predicate.not(CompletableFuture::isDone)) + .toList().size(); + logger.trace("{} unfinished acknowledgement batches", unfinishedAcks); + return unfinishedAcks > 0; + } + + private boolean hasAcksLeft() { + int messagesInAcks = this.acks.size(); + int messagesInAcksBuffer = this.context.acksBuffer.size(); + logger.trace("Acknowledgement queue has {} messages.", messagesInAcks); + logger.trace("Acknowledgement buffer has {} messages.", messagesInAcksBuffer); + return messagesInAcksBuffer > 0 || messagesInAcks > 0; } } @@ -254,7 +301,9 @@ private static class AcknowledgementExecutionContext { private Instant lastAcknowledgement = Instant.now(); - public AcknowledgementExecutionContext(String id, Map>> acksBuffer, + private volatile boolean isTimeoutElapsed = false; + + private AcknowledgementExecutionContext(String id, Map>> acksBuffer, Lock ackLock, Supplier runningFunction, Function>, CompletableFuture> executingFunction) { this.id = id; @@ -314,8 +363,8 @@ private List> pollUpToThreshold(String groupKey, BlockingQueue pollMessage(String groupKey, BlockingQueue> messages) { Message polledMessage = messages.poll(); Assert.notNull(polledMessage, "poll should never return null"); - logger.trace("Retrieved message {} from the queue for group {}. Queue size: {}", - MessageHeaderUtils.getId(polledMessage), groupKey, messages.size()); + logger.trace("Retrieved message {} from the buffer for group {}. Queue size: {} runningAcks: {}", + MessageHeaderUtils.getId(polledMessage), groupKey, messages.size(), this.runningAcks); return polledMessage; } @@ -327,9 +376,11 @@ private CompletableFuture execute(Collection> messages) { private CompletableFuture manageFuture(CompletableFuture future) { this.runningAcks.add(future); + logger.trace("Added future to runningAcks. Total: {}", this.runningAcks.size()); future.whenComplete((v, t) -> { if (isRunning()) { this.runningAcks.remove(future); + logger.trace("Removed future from runningAcks. Total: {}", this.runningAcks.size()); } }); return future; diff --git a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/acknowledgement/BatchingAcknowledgementProcessorTests.java b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/acknowledgement/BatchingAcknowledgementProcessorTests.java index 09ff6e592..1d9fbc351 100644 --- a/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/acknowledgement/BatchingAcknowledgementProcessorTests.java +++ b/spring-cloud-aws-sqs/src/test/java/io/awspring/cloud/sqs/listener/acknowledgement/BatchingAcknowledgementProcessorTests.java @@ -94,7 +94,7 @@ void shouldAckAfterBatch() throws Exception { .collect(Collectors.toList()); given(ackExecutor.execute(messages)).willReturn(CompletableFuture.completedFuture(null)); CountDownLatch ackLatch = new CountDownLatch(1); - BatchingAcknowledgementProcessor processor = new BatchingAcknowledgementProcessor() { + BatchingAcknowledgementProcessor processor = new BatchingAcknowledgementProcessor<>() { @Override protected CompletableFuture sendToExecutor(Collection> messagesToAck) { return super.sendToExecutor(messagesToAck).thenRun(ackLatch::countDown); @@ -116,6 +116,79 @@ protected CompletableFuture sendToExecutor(Collection> mes then(ackExecutor).should().execute(messages); } + @Test + void givenBatchingAcknowledgement_whenEnoughTimeout_shouldAcknowledgeAllMessages() throws Exception { + Duration acknowledgementShutdownTimeout = Duration.ofSeconds(10); + boolean shouldWaitAllAcks = true; + testAckOnShutdown(acknowledgementShutdownTimeout, shouldWaitAllAcks); + } + + @Test + void givenBatchingAcknowledgement_whenNotEnoughTimeout_shouldStopBeforeAllMessages() throws Exception { + Duration acknowledgementShutdownTimeout = Duration.ofSeconds(2); + boolean shouldWaitAllAcks = false; + testAckOnShutdown(acknowledgementShutdownTimeout, shouldWaitAllAcks); + } + + @Test + void givenBatchingAcknowledgement_whenNoTimeout_shouldStopBeforeAllMessages() throws Exception { + Duration acknowledgementShutdownTimeout = Duration.ZERO; + boolean shouldWaitAllAcks = false; + testAckOnShutdown(acknowledgementShutdownTimeout, shouldWaitAllAcks); + } + + private static void testAckOnShutdown(Duration acknowledgementShutdownTimeout, boolean shouldWaitAllAcks) + throws InterruptedException { + TaskExecutor executor = new SimpleAsyncTaskExecutor(); + List> messages = IntStream.range(0, 100) + .mapToObj(index -> MessageBuilder.withPayload(String.valueOf(index)).build()) + .collect(Collectors.toList()); + + CountDownLatch ackLatch = new CountDownLatch(10); + var ackExecutor = new AcknowledgementExecutor() { + + @Override + public CompletableFuture execute(Collection> messages) { + return CompletableFuture.runAsync(() -> { + try { + int timeToSleep = Integer.parseInt(messages.iterator().next().getPayload()) * 100; + logger.info("Executing a list of {} messages in {} milliseconds..", messages.size(), + timeToSleep); + Thread.sleep(timeToSleep); + logger.info("Executed a list of {} messages. Counting down.", messages.size()); + ackLatch.countDown(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + }, executor); + } + }; + + BatchingAcknowledgementProcessor processor = new BatchingAcknowledgementProcessor<>() { + @Override + protected CompletableFuture sendToExecutor(Collection> messagesToAck) { + return super.sendToExecutor(messagesToAck); + } + }; + SqsContainerOptions options = SqsContainerOptions.builder().acknowledgementInterval(ACK_INTERVAL_ZERO) + .acknowledgementThreshold(ACK_THRESHOLD_TEN).acknowledgementOrdering(AcknowledgementOrdering.PARALLEL) + .acknowledgementShutdownTimeout(acknowledgementShutdownTimeout).build(); + processor.configure(options); + processor.setTaskExecutor(executor); + processor.setAcknowledgementExecutor(ackExecutor); + processor.setMaxAcknowledgementsPerBatch(MAX_ACKNOWLEDGEMENTS_PER_BATCH_TEN); + processor.setId(ID); + processor.start(); + + processor.doOnAcknowledge(messages); + + processor.stop(); + logger.debug("Processor stopped, waiting on latch"); + assertThat(ackLatch.await(1, TimeUnit.SECONDS)).isEqualTo(shouldWaitAllAcks); + } + @Test void shouldAckAfterTime() throws Exception { given(message.getHeaders()).willReturn(messageHeaders); diff --git a/spring-cloud-aws-sqs/src/test/resources/logback.xml b/spring-cloud-aws-sqs/src/test/resources/logback.xml index 35f7dd156..99b1679d4 100644 --- a/spring-cloud-aws-sqs/src/test/resources/logback.xml +++ b/spring-cloud-aws-sqs/src/test/resources/logback.xml @@ -37,7 +37,7 @@ - + @@ -45,7 +45,7 @@ - +