Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Acknowledgement Graceful Shutdown #1082

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.awspring.cloud.sqs.listener.TaskExecutorAware;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
Expand All @@ -35,6 +36,7 @@
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -105,7 +107,8 @@ protected CompletableFuture<Void> doOnAcknowledge(Message<T> message) {
logger.warn("Acknowledgement queue full, dropping acknowledgement for message {}",
MessageHeaderUtils.getId(message));
}
logger.trace("Received message {} to ack in {}.", MessageHeaderUtils.getId(message), getId());
logger.trace("Received message {} to ack in {}. Primary queue size: {}", MessageHeaderUtils.getId(message),
getId(), acks.size());
return CompletableFuture.completedFuture(null);
}

Expand Down Expand Up @@ -143,7 +146,11 @@ protected BufferingAcknowledgementProcessor<T> createAcknowledgementProcessor()

@Override
public void doStop() {
this.acknowledgementProcessor.waitAcknowledgementsToFinish();
try {
this.acknowledgementProcessor.waitAcknowledgementsToFinish();
} catch (Exception e) {
logger.error("Error waiting for acknowledgements to finish. Proceeding with shutdown.", e);
}
LifecycleHandler.get().dispose(this.taskScheduler);
}

Expand Down Expand Up @@ -187,7 +194,7 @@ private BufferingAcknowledgementProcessor(BatchingAcknowledgementProcessor<T> pa
public void run() {
logger.debug("Starting acknowledgement processor thread with batchSize: {}", this.ackThreshold);
this.scheduledExecution.start();
while (this.parent.isRunning()) {
while (shouldKeepPollingAcks()) {
try {
Message<T> polledMessage = this.acks.poll(1, TimeUnit.SECONDS);
if (polledMessage != null) {
Expand All @@ -202,6 +209,10 @@ public void run() {
logger.debug("Acknowledgement processor thread stopped");
}

/**
 * Decides whether the polling loop should take another pass over the acks queue.
 * Polling goes on while the parent reports that it is running and, once it no longer
 * does, until the context's {@code isTimeoutElapsed} flag has been set.
 *
 * @return {@code true} while polling should continue
 */
private boolean shouldKeepPollingAcks() {
    if (this.parent.isRunning()) {
        return true;
    }
    // Parent is stopping: keep draining until the shutdown wait flags the timeout as elapsed.
    return !this.context.isTimeoutElapsed;
}

private void addMessageToBuffer(Message<T> polledMessage) {
this.context.lock();
try {
Expand All @@ -214,9 +225,33 @@ private void addMessageToBuffer(Message<T> polledMessage) {
}

public void waitAcknowledgementsToFinish() {
waitOnAcknowledgementsIfTimeoutSet();
this.context.isTimeoutElapsed = true;
this.context.lock();
try {
CompletableFuture.allOf(this.context.runningAcks.toArray(new CompletableFuture[] {}))
.get(this.ackShutdownTimeout.toMillis(), TimeUnit.MILLISECONDS);
this.context.acksBuffer.clear();
}
finally {
this.context.unlock();
}
this.context.runningAcks.forEach(future -> future.cancel(true));
}

private void waitOnAcknowledgementsIfTimeoutSet() {
if (Duration.ZERO.equals(this.ackShutdownTimeout)) {
logger.debug("Not waiting for acknowledgements, shutting down.");
return;
}
try {
var endTime = LocalDateTime.now().plus(this.ackShutdownTimeout);
logger.debug("Waiting until {} for acknowledgements to finish", endTime);
while (hasAcksLeft() || hasUnfinishedAcks()) {
if (LocalDateTime.now().isAfter(endTime)) {
throw new TimeoutException();
}
Thread.sleep(200);
}
logger.debug("All acknowledgements completed.");
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
Expand All @@ -231,9 +266,21 @@ public void waitAcknowledgementsToFinish() {
"Error thrown when waiting for acknowledgement tasks to finish in {}. Continuing with shutdown.",
this.parent.getId(), e);
}
if (!this.context.runningAcks.isEmpty()) {
this.context.runningAcks.forEach(future -> future.cancel(true));
}
}

/**
 * Reports whether any future tracked in the context's {@code runningAcks} has not
 * completed yet. The number of pending futures is logged at trace level.
 *
 * @return {@code true} when at least one tracked future is not done
 */
private boolean hasUnfinishedAcks() {
    long pendingBatches = this.context.runningAcks.stream().filter(ack -> !ack.isDone()).count();
    logger.trace("{} unfinished acknowledgement batches", pendingBatches);
    return pendingBatches != 0;
}

/**
 * Reports whether any message is still waiting to be acknowledged, either in the primary
 * {@code acks} queue or in one of the per-group queues held by the context's buffer.
 * <p>
 * The buffer maps a group key to a queue of messages, so the map's own {@code size()} is the
 * number of groups, not the number of buffered messages. A group whose queue has already been
 * drained would still be counted as pending work. The queue sizes are summed instead, which
 * also makes the value match what the trace message says it is.
 *
 * @return {@code true} when at least one message is queued or buffered
 */
private boolean hasAcksLeft() {
    int messagesInAcks = this.acks.size();
    int messagesInAcksBuffer;
    // Iterate under the same lock that is taken before the buffer is cleared on shutdown,
    // so the per-group queues are not being swapped out while they are counted.
    this.context.lock();
    try {
        messagesInAcksBuffer = this.context.acksBuffer.values().stream().mapToInt(Collection::size).sum();
    }
    finally {
        this.context.unlock();
    }
    logger.trace("Acknowledgement queue has {} messages.", messagesInAcks);
    logger.trace("Acknowledgement buffer has {} messages.", messagesInAcksBuffer);
    return messagesInAcksBuffer > 0 || messagesInAcks > 0;
}

}
Expand All @@ -254,7 +301,9 @@ private static class AcknowledgementExecutionContext<T> {

private Instant lastAcknowledgement = Instant.now();

public AcknowledgementExecutionContext(String id, Map<String, BlockingQueue<Message<T>>> acksBuffer,
private volatile boolean isTimeoutElapsed = false;

private AcknowledgementExecutionContext(String id, Map<String, BlockingQueue<Message<T>>> acksBuffer,
Lock ackLock, Supplier<Boolean> runningFunction,
Function<Collection<Message<T>>, CompletableFuture<Void>> executingFunction) {
this.id = id;
Expand Down Expand Up @@ -314,8 +363,8 @@ private List<Message<T>> pollUpToThreshold(String groupKey, BlockingQueue<Messag
private Message<T> pollMessage(String groupKey, BlockingQueue<Message<T>> messages) {
Message<T> polledMessage = messages.poll();
Assert.notNull(polledMessage, "poll should never return null");
logger.trace("Retrieved message {} from the queue for group {}. Queue size: {}",
MessageHeaderUtils.getId(polledMessage), groupKey, messages.size());
logger.trace("Retrieved message {} from the buffer for group {}. Queue size: {} runningAcks: {}",
MessageHeaderUtils.getId(polledMessage), groupKey, messages.size(), this.runningAcks);
return polledMessage;
}

Expand All @@ -327,9 +376,11 @@ private CompletableFuture<Void> execute(Collection<Message<T>> messages) {

private CompletableFuture<Void> manageFuture(CompletableFuture<Void> future) {
this.runningAcks.add(future);
logger.trace("Added future to runningAcks. Total: {}", this.runningAcks.size());
future.whenComplete((v, t) -> {
if (isRunning()) {
this.runningAcks.remove(future);
logger.trace("Removed future from runningAcks. Total: {}", this.runningAcks.size());
}
});
return future;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ void shouldAckAfterBatch() throws Exception {
.collect(Collectors.toList());
given(ackExecutor.execute(messages)).willReturn(CompletableFuture.completedFuture(null));
CountDownLatch ackLatch = new CountDownLatch(1);
BatchingAcknowledgementProcessor<String> processor = new BatchingAcknowledgementProcessor<String>() {
BatchingAcknowledgementProcessor<String> processor = new BatchingAcknowledgementProcessor<>() {
@Override
protected CompletableFuture<Void> sendToExecutor(Collection<Message<String>> messagesToAck) {
return super.sendToExecutor(messagesToAck).thenRun(ackLatch::countDown);
Expand All @@ -116,6 +116,79 @@ protected CompletableFuture<Void> sendToExecutor(Collection<Message<String>> mes
then(ackExecutor).should().execute(messages);
}

@Test
void givenBatchingAcknowledgement_whenEnoughTimeout_shouldAcknowledgeAllMessages() throws Exception {
    // Ten second shutdown timeout; every acknowledgement batch is expected to have run.
    testAckOnShutdown(Duration.ofSeconds(10), true);
}

@Test
void givenBatchingAcknowledgement_whenNotEnoughTimeout_shouldStopBeforeAllMessages() throws Exception {
    // Two second shutdown timeout; the processor is expected to stop with batches still outstanding.
    testAckOnShutdown(Duration.ofSeconds(2), false);
}

@Test
void givenBatchingAcknowledgement_whenNoTimeout_shouldStopBeforeAllMessages() throws Exception {
    // Zero shutdown timeout; the processor is expected to stop without waiting on acknowledgements.
    testAckOnShutdown(Duration.ZERO, false);
}

/**
 * Shared scenario for the shutdown tests: starts a batching processor backed by a deliberately
 * slow acknowledgement executor, hands it 100 messages, stops it, and then checks whether the
 * latch of ten executor invocations reaches zero within one second of {@code stop()} returning.
 *
 * @param acknowledgementShutdownTimeout the shutdown timeout configured on the container options
 * @param shouldWaitAllAcks whether all ten executions are expected to have completed
 * @throws InterruptedException if the calling thread is interrupted while awaiting the latch
 */
private static void testAckOnShutdown(Duration acknowledgementShutdownTimeout, boolean shouldWaitAllAcks)
        throws InterruptedException {
    TaskExecutor taskExecutor = new SimpleAsyncTaskExecutor();
    CountDownLatch executionsLatch = new CountDownLatch(10);

    List<Message<String>> messagesToAck = IntStream.range(0, 100)
            .mapToObj(String::valueOf)
            .map(payload -> MessageBuilder.withPayload(payload).build())
            .collect(Collectors.toList());

    AcknowledgementExecutor<String> slowAckExecutor = new AcknowledgementExecutor<String>() {

        @Override
        public CompletableFuture<Void> execute(Collection<Message<String>> batch) {
            Runnable slowAck = () -> {
                try {
                    // The sleep time is derived from the payload of the first message in the batch.
                    int timeToSleep = Integer.parseInt(batch.iterator().next().getPayload()) * 100;
                    logger.info("Executing a list of {} messages in {} milliseconds..", batch.size(),
                            timeToSleep);
                    Thread.sleep(timeToSleep);
                    logger.info("Executed a list of {} messages. Counting down.", batch.size());
                    executionsLatch.countDown();
                }
                catch (InterruptedException e) {
                    // Restore the interrupt flag before failing the future.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
            };
            return CompletableFuture.runAsync(slowAck, taskExecutor);
        }
    };

    SqsContainerOptions containerOptions = SqsContainerOptions.builder()
            .acknowledgementInterval(ACK_INTERVAL_ZERO)
            .acknowledgementThreshold(ACK_THRESHOLD_TEN)
            .acknowledgementOrdering(AcknowledgementOrdering.PARALLEL)
            .acknowledgementShutdownTimeout(acknowledgementShutdownTimeout)
            .build();

    BatchingAcknowledgementProcessor<String> ackProcessor = new BatchingAcknowledgementProcessor<>() {
        @Override
        protected CompletableFuture<Void> sendToExecutor(Collection<Message<String>> messagesToAck) {
            // Plain delegation, kept as in the other tests' anonymous subclasses.
            return super.sendToExecutor(messagesToAck);
        }
    };
    ackProcessor.configure(containerOptions);
    ackProcessor.setTaskExecutor(taskExecutor);
    ackProcessor.setAcknowledgementExecutor(slowAckExecutor);
    ackProcessor.setMaxAcknowledgementsPerBatch(MAX_ACKNOWLEDGEMENTS_PER_BATCH_TEN);
    ackProcessor.setId(ID);
    ackProcessor.start();

    ackProcessor.doOnAcknowledge(messagesToAck);

    ackProcessor.stop();
    logger.debug("Processor stopped, waiting on latch");
    boolean allExecutionsFinished = executionsLatch.await(1, TimeUnit.SECONDS);
    assertThat(allExecutionsFinished).isEqualTo(shouldWaitAllAcks);
}

@Test
void shouldAckAfterTime() throws Exception {
given(message.getHeaders()).willReturn(messageHeaders);
Expand Down
4 changes: 2 additions & 2 deletions spring-cloud-aws-sqs/src/test/resources/logback.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@
<logger name="io.awspring.cloud.sqs.listener.acknowledgement.handler" level="INFO"/>
<logger name="io.awspring.cloud.sqs.listener.acknowledgement.BatchingAcknowledgementProcessor" level="INFO"/>
<logger name="io.awspring.cloud.sqs.listener.acknowledgement.SqsAcknowledgementExecutor" level="INFO"/>
<logger name="io.awspring.cloud.sqs.operations" level="TRACE"/>
<logger name="io.awspring.cloud.sqs.operations" level="INFO"/>

<logger name="io.awspring.cloud.sqs.integration.BaseSqsIntegrationTest" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsFifoIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsLoadIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsInterceptorIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsMessageConversionIntegrationTests" level="INFO"/>
<logger name="io.awspring.cloud.sqs.integration.SqsTemplateIntegrationTests" level="TRACE"/>
<logger name="io.awspring.cloud.sqs.integration.SqsTemplateIntegrationTests" level="INFO"/>

<root level="warn">
<appender-ref ref="STDOUT"/>
Expand Down
Loading