Skip to content

Commit

Permalink
Fix Racing Condition on Acknowledgement Graceful Shutdown (#1082)
Browse files Browse the repository at this point in the history
BatchingAcknowledgementProcessor had a racing condition where if the processor was stopped while there where still messages to ack in the main ack queue, it would stop polling and fail to acknowledge such messages even if there was time left for it.

This commit fixes this and adds relevant test coverage.
  • Loading branch information
tomazfernandes authored Mar 12, 2024
1 parent d4caa7f commit 2378662
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.awspring.cloud.sqs.listener.TaskExecutorAware;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
Expand All @@ -35,6 +36,7 @@
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -105,7 +107,8 @@ protected CompletableFuture<Void> doOnAcknowledge(Message<T> message) {
logger.warn("Acknowledgement queue full, dropping acknowledgement for message {}",
MessageHeaderUtils.getId(message));
}
logger.trace("Received message {} to ack in {}.", MessageHeaderUtils.getId(message), getId());
logger.trace("Received message {} to ack in {}. Primary queue size: {}", MessageHeaderUtils.getId(message),
getId(), acks.size());
return CompletableFuture.completedFuture(null);
}

Expand Down Expand Up @@ -143,7 +146,11 @@ protected BufferingAcknowledgementProcessor<T> createAcknowledgementProcessor()

@Override
public void doStop() {
this.acknowledgementProcessor.waitAcknowledgementsToFinish();
try {
this.acknowledgementProcessor.waitAcknowledgementsToFinish();
} catch (Exception e) {
logger.error("Error waiting for acknowledgements to finish. Proceeding with shutdown.", e);
}
LifecycleHandler.get().dispose(this.taskScheduler);
}

Expand Down Expand Up @@ -187,7 +194,7 @@ private BufferingAcknowledgementProcessor(BatchingAcknowledgementProcessor<T> pa
public void run() {
logger.debug("Starting acknowledgement processor thread with batchSize: {}", this.ackThreshold);
this.scheduledExecution.start();
while (this.parent.isRunning()) {
while (shouldKeepPollingAcks()) {
try {
Message<T> polledMessage = this.acks.poll(1, TimeUnit.SECONDS);
if (polledMessage != null) {
Expand All @@ -202,6 +209,10 @@ public void run() {
logger.debug("Acknowledgement processor thread stopped");
}

private boolean shouldKeepPollingAcks() {
return this.parent.isRunning() || !this.context.isTimeoutElapsed;
}

private void addMessageToBuffer(Message<T> polledMessage) {
this.context.lock();
try {
Expand All @@ -214,9 +225,33 @@ private void addMessageToBuffer(Message<T> polledMessage) {
}

public void waitAcknowledgementsToFinish() {
waitOnAcknowledgementsIfTimeoutSet();
this.context.isTimeoutElapsed = true;
this.context.lock();
try {
CompletableFuture.allOf(this.context.runningAcks.toArray(new CompletableFuture[] {}))
.get(this.ackShutdownTimeout.toMillis(), TimeUnit.MILLISECONDS);
this.context.acksBuffer.clear();
}
finally {
this.context.unlock();
}
this.context.runningAcks.forEach(future -> future.cancel(true));
}

private void waitOnAcknowledgementsIfTimeoutSet() {
if (Duration.ZERO.equals(this.ackShutdownTimeout)) {
logger.debug("Not waiting for acknowledgements, shutting down.");
return;
}
try {
var endTime = LocalDateTime.now().plus(this.ackShutdownTimeout);
logger.debug("Waiting until {} for acknowledgements to finish", endTime);
while (hasAcksLeft() || hasUnfinishedAcks()) {
if (LocalDateTime.now().isAfter(endTime)) {
throw new TimeoutException();
}
Thread.sleep(200);
}
logger.debug("All acknowledgements completed.");
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
Expand All @@ -231,9 +266,21 @@ public void waitAcknowledgementsToFinish() {
"Error thrown when waiting for acknowledgement tasks to finish in {}. Continuing with shutdown.",
this.parent.getId(), e);
}
if (!this.context.runningAcks.isEmpty()) {
this.context.runningAcks.forEach(future -> future.cancel(true));
}
}

private boolean hasUnfinishedAcks() {
var unfinishedAcks = this.context.runningAcks.stream().filter(Predicate.not(CompletableFuture::isDone))
.toList().size();
logger.trace("{} unfinished acknowledgement batches", unfinishedAcks);
return unfinishedAcks > 0;
}

private boolean hasAcksLeft() {
int messagesInAcks = this.acks.size();
int messagesInAcksBuffer = this.context.acksBuffer.size();
logger.trace("Acknowledgement queue has {} messages.", messagesInAcks);
logger.trace("Acknowledgement buffer has {} messages.", messagesInAcksBuffer);
return messagesInAcksBuffer > 0 || messagesInAcks > 0;
}

}
Expand All @@ -254,7 +301,9 @@ private static class AcknowledgementExecutionContext<T> {

private Instant lastAcknowledgement = Instant.now();

public AcknowledgementExecutionContext(String id, Map<String, BlockingQueue<Message<T>>> acksBuffer,
private volatile boolean isTimeoutElapsed = false;

private AcknowledgementExecutionContext(String id, Map<String, BlockingQueue<Message<T>>> acksBuffer,
Lock ackLock, Supplier<Boolean> runningFunction,
Function<Collection<Message<T>>, CompletableFuture<Void>> executingFunction) {
this.id = id;
Expand Down Expand Up @@ -314,8 +363,8 @@ private List<Message<T>> pollUpToThreshold(String groupKey, BlockingQueue<Messag
private Message<T> pollMessage(String groupKey, BlockingQueue<Message<T>> messages) {
Message<T> polledMessage = messages.poll();
Assert.notNull(polledMessage, "poll should never return null");
logger.trace("Retrieved message {} from the queue for group {}. Queue size: {}",
MessageHeaderUtils.getId(polledMessage), groupKey, messages.size());
logger.trace("Retrieved message {} from the buffer for group {}. Queue size: {} runningAcks: {}",
MessageHeaderUtils.getId(polledMessage), groupKey, messages.size(), this.runningAcks);
return polledMessage;
}

Expand All @@ -327,9 +376,11 @@ private CompletableFuture<Void> execute(Collection<Message<T>> messages) {

private CompletableFuture<Void> manageFuture(CompletableFuture<Void> future) {
this.runningAcks.add(future);
logger.trace("Added future to runningAcks. Total: {}", this.runningAcks.size());
future.whenComplete((v, t) -> {
if (isRunning()) {
this.runningAcks.remove(future);
logger.trace("Removed future from runningAcks. Total: {}", this.runningAcks.size());
}
});
return future;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void shouldAckAfterBatch() throws Exception {
.collect(Collectors.toList());
given(ackExecutor.execute(messages)).willReturn(CompletableFuture.completedFuture(null));
CountDownLatch ackLatch = new CountDownLatch(1);
BatchingAcknowledgementProcessor<String> processor = new BatchingAcknowledgementProcessor<String>() {
BatchingAcknowledgementProcessor<String> processor = new BatchingAcknowledgementProcessor<>() {
@Override
protected CompletableFuture<Void> sendToExecutor(Collection<Message<String>> messagesToAck) {
return super.sendToExecutor(messagesToAck).thenRun(ackLatch::countDown);
Expand All @@ -116,6 +116,79 @@ protected CompletableFuture<Void> sendToExecutor(Collection<Message<String>> mes
then(ackExecutor).should().execute(messages);
}

@Test
void givenBatchingAcknowledgement_whenEnoughTimeout_shouldAcknowledgeAllMessages() throws Exception {
Duration acknowledgementShutdownTimeout = Duration.ofSeconds(10);
boolean shouldWaitAllAcks = true;
testAckOnShutdown(acknowledgementShutdownTimeout, shouldWaitAllAcks);
}

@Test
void givenBatchingAcknowledgement_whenNotEnoughTimeout_shouldStopBeforeAllMessages() throws Exception {
Duration acknowledgementShutdownTimeout = Duration.ofSeconds(2);
boolean shouldWaitAllAcks = false;
testAckOnShutdown(acknowledgementShutdownTimeout, shouldWaitAllAcks);
}

@Test
void givenBatchingAcknowledgement_whenNoTimeout_shouldStopBeforeAllMessages() throws Exception {
Duration acknowledgementShutdownTimeout = Duration.ZERO;
boolean shouldWaitAllAcks = false;
testAckOnShutdown(acknowledgementShutdownTimeout, shouldWaitAllAcks);
}

private static void testAckOnShutdown(Duration acknowledgementShutdownTimeout, boolean shouldWaitAllAcks)
throws InterruptedException {
TaskExecutor executor = new SimpleAsyncTaskExecutor();
List<Message<String>> messages = IntStream.range(0, 100)
.mapToObj(index -> MessageBuilder.withPayload(String.valueOf(index)).build())
.collect(Collectors.toList());

CountDownLatch ackLatch = new CountDownLatch(10);
var ackExecutor = new AcknowledgementExecutor<String>() {

@Override
public CompletableFuture<Void> execute(Collection<Message<String>> messages) {
return CompletableFuture.runAsync(() -> {
try {
int timeToSleep = Integer.parseInt(messages.iterator().next().getPayload()) * 100;
logger.info("Executing a list of {} messages in {} milliseconds..", messages.size(),
timeToSleep);
Thread.sleep(timeToSleep);
logger.info("Executed a list of {} messages. Counting down.", messages.size());
ackLatch.countDown();
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
}, executor);
}
};

BatchingAcknowledgementProcessor<String> processor = new BatchingAcknowledgementProcessor<>() {
@Override
protected CompletableFuture<Void> sendToExecutor(Collection<Message<String>> messagesToAck) {
return super.sendToExecutor(messagesToAck);
}
};
SqsContainerOptions options = SqsContainerOptions.builder().acknowledgementInterval(ACK_INTERVAL_ZERO)
.acknowledgementThreshold(ACK_THRESHOLD_TEN).acknowledgementOrdering(AcknowledgementOrdering.PARALLEL)
.acknowledgementShutdownTimeout(acknowledgementShutdownTimeout).build();
processor.configure(options);
processor.setTaskExecutor(executor);
processor.setAcknowledgementExecutor(ackExecutor);
processor.setMaxAcknowledgementsPerBatch(MAX_ACKNOWLEDGEMENTS_PER_BATCH_TEN);
processor.setId(ID);
processor.start();

processor.doOnAcknowledge(messages);

processor.stop();
logger.debug("Processor stopped, waiting on latch");
assertThat(ackLatch.await(1, TimeUnit.SECONDS)).isEqualTo(shouldWaitAllAcks);
}

@Test
void shouldAckAfterTime() throws Exception {
given(message.getHeaders()).willReturn(messageHeaders);
Expand Down
4 changes: 2 additions & 2 deletions spring-cloud-aws-sqs/src/test/resources/logback.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@
<logger name="io.awspring.cloud.sqs.listener.acknowledgement.handler" level="INFO"/>
<logger name="io.awspring.cloud.sqs.listener.acknowledgement.BatchingAcknowledgementProcessor" level="INFO"/>
<logger name="io.awspring.cloud.sqs.listener.acknowledgement.SqsAcknowledgementExecutor" level="INFO"/>
<logger name="io.awspring.cloud.sqs.operations" level="TRACE"/>
<logger name="io.awspring.cloud.sqs.operations" level="INFO"/>

<logger name="io.awspring.cloud.sqs.integration.BaseSqsIntegrationTest" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsFifoIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsLoadIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsInterceptorIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsMessageConversionIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsTemplateIntegrationTests" level="TRACE"/>
<logger name="io.awspring.cloud.sqs.integration.SqsTemplateIntegrationTests" level="INFO"/>

<root level="warn">
<appender-ref ref="STDOUT"/>
Expand Down

0 comments on commit 2378662

Please sign in to comment.