diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java
index d1caa20dd3..1b2019f957 100644
--- a/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java
+++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java
@@ -2643,7 +2643,7 @@ private void doInvokeWithRecords(final ConsumerRecords<K, V> records) {
 
 		private boolean checkImmediatePause(Iterator<ConsumerRecord<K, V>> iterator) {
 			if (isPaused() && this.pauseImmediate) {
-				Map<TopicPartition, List<ConsumerRecord<K, V>>> remaining = new HashMap<>();
+				Map<TopicPartition, List<ConsumerRecord<K, V>>> remaining = new LinkedHashMap<>();
 				while (iterator.hasNext()) {
 					ConsumerRecord<K, V> next = iterator.next();
 					remaining.computeIfAbsent(new TopicPartition(next.topic(), next.partition()),
@@ -3498,6 +3498,7 @@ private class ListenerConsumerRebalanceListener implements ConsumerRebalanceList
 			@Override
 			public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
 				this.revoked.addAll(partitions);
+				removeRevocationsFromPending(partitions);
 				if (this.consumerAwareListener != null) {
 					this.consumerAwareListener.onPartitionsRevokedBeforeCommit(ListenerConsumer.this.consumer,
 							partitions);
@@ -3537,6 +3538,23 @@ public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
 				}
 			}
 
+			private void removeRevocationsFromPending(Collection<TopicPartition> partitions) {
+				ConsumerRecords<K, V> remaining = ListenerConsumer.this.remainingRecords;
+				if (remaining != null && !partitions.isEmpty()) {
+					Set<TopicPartition> remainingParts = new LinkedHashSet<>(remaining.partitions());
+					remainingParts.removeAll(partitions);
+					if (!remainingParts.isEmpty()) {
+						Map<TopicPartition, List<ConsumerRecord<K, V>>> trimmed = new LinkedHashMap<>();
+						remainingParts.forEach(part -> trimmed.computeIfAbsent(part, tp -> remaining.records(tp)));
+						ListenerConsumer.this.remainingRecords = new ConsumerRecords<>(trimmed);
+					}
+					else {
+						ListenerConsumer.this.remainingRecords = null;
+					}
+					ListenerConsumer.this.logger.debug(() -> "Removed " + partitions + " from remaining records");
+				}
+			}
+
 			@Override
 			public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
 				repauseIfNeeded(partitions);
@@ -3568,7 +3586,7 @@ public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
 			}
 
 			private void repauseIfNeeded(Collection<TopicPartition> partitions) {
-				if (isPaused()) {
+				if (isPaused() || ListenerConsumer.this.remainingRecords != null && !partitions.isEmpty()) {
 					ListenerConsumer.this.consumer.pause(partitions);
 					ListenerConsumer.this.consumerPaused = true;
 					ListenerConsumer.this.logger.warn("Paused consumer resumed by Kafka due to rebalance; "
diff --git a/spring-kafka/src/test/java/org/springframework/kafka/listener/ConcurrentMessageListenerContainerMockTests.java b/spring-kafka/src/test/java/org/springframework/kafka/listener/ConcurrentMessageListenerContainerMockTests.java
index 482e9f286e..6a34cc9982 100644
--- a/spring-kafka/src/test/java/org/springframework/kafka/listener/ConcurrentMessageListenerContainerMockTests.java
+++ b/spring-kafka/src/test/java/org/springframework/kafka/listener/ConcurrentMessageListenerContainerMockTests.java
@@ -22,6 +22,7 @@
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.willAnswer;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
@@ -33,6 +34,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -897,6 +899,221 @@ void removeFromPartitionPauseRequestedWhenNotAssigned() throws InterruptedExcept
 		assertThat(child.isPartitionPauseRequested(tp0)).isFalse();
 		rebal.get().onPartitionsAssigned(assignments);
 		verify(consumer, times(2)).pause(any()); // no immediate pause this time
+		container.stop();
+	}
+
+	@SuppressWarnings({ "unchecked", "rawtypes" })
+	@Test
+	void pruneRevokedPartitionsFromRemainingRecordsWhenSeekAfterErrorFalseLegacyAssignor() throws InterruptedException {
+		TopicPartition tp0 = new TopicPartition("foo", 0);
+		TopicPartition tp1 = new TopicPartition("foo", 1);
+		TopicPartition tp2 = new TopicPartition("foo", 2);
+		TopicPartition tp3 = new TopicPartition("foo", 3);
+		List<TopicPartition> allAssignments = Arrays.asList(tp0, tp1, tp2, tp3);
+		Map<TopicPartition, List<ConsumerRecord<String, String>>> allRecordMap = new HashMap<>();
+		allRecordMap.put(tp0, Collections.singletonList(new ConsumerRecord("foo", 0, 0, null, "bar")));
+		allRecordMap.put(tp1, Collections.singletonList(new ConsumerRecord("foo", 1, 0, null, "bar")));
+		allRecordMap.put(tp2, Collections.singletonList(new ConsumerRecord("foo", 2, 0, null, "bar")));
+		allRecordMap.put(tp3, Collections.singletonList(new ConsumerRecord("foo", 3, 0, null, "bar")));
+		ConsumerRecords allRecords = new ConsumerRecords<>(allRecordMap);
+		List<TopicPartition> afterRevokeAssignments = Arrays.asList(tp1, tp3);
+		Map<TopicPartition, List<ConsumerRecord<String, String>>> afterRevokeRecordMap = new HashMap<>();
+		afterRevokeRecordMap.put(tp1, Collections.singletonList(new ConsumerRecord("foo", 1, 0, null, "bar")));
+		afterRevokeRecordMap.put(tp3, Collections.singletonList(new ConsumerRecord("foo", 3, 0, null, "bar")));
+		ConsumerRecords afterRevokeRecords = new ConsumerRecords<>(afterRevokeRecordMap);
+		AtomicInteger pollPhase = new AtomicInteger();
+
+		Consumer consumer = mock(Consumer.class);
+		AtomicReference<ConsumerRebalanceListener> rebal = new AtomicReference<>();
+		CountDownLatch subscribeLatch = new CountDownLatch(1);
+		willAnswer(invocation -> {
+			rebal.set(invocation.getArgument(1));
+			subscribeLatch.countDown();
+			return null;
+		}).given(consumer).subscribe(any(Collection.class), any());
+		CountDownLatch pauseLatch = new CountDownLatch(1);
+		AtomicBoolean paused = new AtomicBoolean();
+		willAnswer(inv -> {
+			paused.set(true);
+			pauseLatch.countDown();
+			return null;
+		}).given(consumer).pause(any());
+		CountDownLatch resumeLatch = new CountDownLatch(1);
+		willAnswer(inv -> {
+			paused.set(false);
+			resumeLatch.countDown();
+			return null;
+		}).given(consumer).resume(any());
+		ConsumerFactory cf = mock(ConsumerFactory.class);
+		given(cf.createConsumer(any(), any(), any(), any())).willReturn(consumer);
+		given(cf.getConfigurationProperties())
+				.willReturn(Collections.singletonMap(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"));
+		ContainerProperties containerProperties = new ContainerProperties("foo");
+		containerProperties.setGroupId("grp");
+		containerProperties.setMessageListener((MessageListener) rec -> {
+			throw new RuntimeException("test");
+		});
+		ConcurrentMessageListenerContainer container = new ConcurrentMessageListenerContainer(cf,
+				containerProperties);
+		container.setCommonErrorHandler(new CommonErrorHandler() {
+
+			@Override
+			public boolean seeksAfterHandling() {
+				return false; // pause and use remainingRecords
+			}
+
+			@Override
+			public boolean handleOne(Exception thrownException, ConsumerRecord<?, ?> record, Consumer<?, ?> consumer,
+					MessageListenerContainer container) {
+
+				return false; // not handled
+			}
+
+		});
+		CountDownLatch pollLatch = new CountDownLatch(2);
+		CountDownLatch rebalLatch = new CountDownLatch(1);
+		CountDownLatch continueLatch = new CountDownLatch(1);
+		willAnswer(inv -> {
+			Thread.sleep(50);
+			pollLatch.countDown();
+			switch (pollPhase.getAndIncrement()) {
+				case 0:
+					rebal.get().onPartitionsAssigned(allAssignments);
+					return allRecords;
+				case 1:
+					rebal.get().onPartitionsRevoked(allAssignments);
+					rebal.get().onPartitionsAssigned(afterRevokeAssignments);
+					rebalLatch.countDown();
+					continueLatch.await(10, TimeUnit.SECONDS);
+				default:
+					if (paused.get()) {
+						return ConsumerRecords.empty();
+					}
+					return afterRevokeRecords;
+			}
+		}).given(consumer).poll(any());
+		container.start();
+		assertThat(subscribeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		KafkaMessageListenerContainer child = (KafkaMessageListenerContainer) KafkaTestUtils
+				.getPropertyValue(container, "containers", List.class).get(0);
+		assertThat(pollLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		assertThat(pauseLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		assertThat(rebalLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		ConsumerRecords remaining = KafkaTestUtils.getPropertyValue(child, "listenerConsumer.remainingRecords",
+				ConsumerRecords.class);
+		assertThat(remaining).isNull();
+		continueLatch.countDown();
+		assertThat(resumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		// no pause when re-assigned because all revoked and remainingRecords == null
+		verify(consumer).pause(any());
+		verify(consumer).resume(any());
+		container.stop();
+	}
+
+	@SuppressWarnings({ "unchecked", "rawtypes" })
+	@Test
+	void pruneRevokedPartitionsFromRemainingRecordsWhenSeekAfterErrorFalseCoopAssignor() throws InterruptedException {
+		TopicPartition tp0 = new TopicPartition("foo", 0);
+		TopicPartition tp1 = new TopicPartition("foo", 1);
+		TopicPartition tp2 = new TopicPartition("foo", 2);
+		TopicPartition tp3 = new TopicPartition("foo", 3);
+		List<TopicPartition> allAssignments = Arrays.asList(tp0, tp1, tp2, tp3);
+		Map<TopicPartition, List<ConsumerRecord<String, String>>> allRecordMap = new LinkedHashMap<>();
+		ConsumerRecord record0 = new ConsumerRecord("foo", 0, 0, null, "bar");
+		ConsumerRecord record1 = new ConsumerRecord("foo", 1, 0, null, "bar");
+		allRecordMap.put(tp0, Collections.singletonList(record0));
+		allRecordMap.put(tp1, Collections.singletonList(record1));
+		allRecordMap.put(tp2, Collections.singletonList(new ConsumerRecord("foo", 2, 0, null, "bar")));
+		allRecordMap.put(tp3, Collections.singletonList(new ConsumerRecord("foo", 3, 0, null, "bar")));
+		ConsumerRecords allRecords = new ConsumerRecords<>(allRecordMap);
+		List<TopicPartition> revokedAssignments = Arrays.asList(tp0, tp2);
+		AtomicInteger pollPhase = new AtomicInteger();
+
+		Consumer consumer = mock(Consumer.class);
+		AtomicReference<ConsumerRebalanceListener> rebal = new AtomicReference<>();
+		CountDownLatch subscribeLatch = new CountDownLatch(1);
+		willAnswer(invocation -> {
+			rebal.set(invocation.getArgument(1));
+			subscribeLatch.countDown();
+			return null;
+		}).given(consumer).subscribe(any(Collection.class), any());
+		CountDownLatch pauseLatch = new CountDownLatch(1);
+		AtomicBoolean paused = new AtomicBoolean();
+		willAnswer(inv -> {
+			paused.set(true);
+			pauseLatch.countDown();
+			return null;
+		}).given(consumer).pause(any());
+		ConsumerFactory cf = mock(ConsumerFactory.class);
+		given(cf.createConsumer(any(), any(), any(), any())).willReturn(consumer);
+		given(cf.getConfigurationProperties())
+				.willReturn(Collections.singletonMap(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"));
+		ContainerProperties containerProperties = new ContainerProperties("foo");
+		containerProperties.setGroupId("grp");
+		List<ConsumerRecord> recordsDelivered = new ArrayList<>();
+		CountDownLatch consumeLatch = new CountDownLatch(3);
+		containerProperties.setMessageListener((MessageListener) rec -> {
+			recordsDelivered.add((ConsumerRecord) rec);
+			consumeLatch.countDown();
+			throw new RuntimeException("test");
+		});
+		ConcurrentMessageListenerContainer container = new ConcurrentMessageListenerContainer(cf,
+				containerProperties);
+		container.setCommonErrorHandler(new CommonErrorHandler() {
+
+			@Override
+			public boolean seeksAfterHandling() {
+				return false; // pause and use remainingRecords
+			}
+
+			@Override
+			public boolean handleOne(Exception thrownException, ConsumerRecord<?, ?> record, Consumer<?, ?> consumer,
+					MessageListenerContainer container) {
+
+				return false; // not handled
+			}
+
+		});
+		CountDownLatch pollLatch = new CountDownLatch(2);
+		CountDownLatch rebalLatch = new CountDownLatch(1);
+		CountDownLatch continueLatch = new CountDownLatch(1);
+		willAnswer(inv -> {
+			Thread.sleep(50);
+			pollLatch.countDown();
+			switch (pollPhase.getAndIncrement()) {
+				case 0:
+					rebal.get().onPartitionsAssigned(allAssignments);
+					return allRecords;
+				case 1:
+					rebal.get().onPartitionsRevoked(revokedAssignments);
+					rebal.get().onPartitionsAssigned(Collections.emptyList());
+					rebalLatch.countDown();
+					continueLatch.await(10, TimeUnit.SECONDS);
+				default:
+					return ConsumerRecords.empty();
+			}
+		}).given(consumer).poll(any());
+		container.start();
+		assertThat(subscribeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		KafkaMessageListenerContainer child = (KafkaMessageListenerContainer) KafkaTestUtils
+				.getPropertyValue(container, "containers", List.class).get(0);
+		assertThat(pollLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		assertThat(pauseLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		assertThat(rebalLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		ConsumerRecords remaining = KafkaTestUtils.getPropertyValue(child, "listenerConsumer.remainingRecords",
+				ConsumerRecords.class);
+		assertThat(remaining.count()).isEqualTo(2);
+		assertThat(remaining.partitions()).contains(tp1, tp3);
+		continueLatch.countDown();
+		verify(consumer, atLeastOnce()).pause(any());
+		verify(consumer, never()).resume(any());
+		assertThat(consumeLatch.await(10, TimeUnit.SECONDS)).isTrue();
+		container.stop();
+		assertThat(recordsDelivered).hasSizeGreaterThanOrEqualTo(3);
+		// partitions 0, 2 revoked during second poll.
+		assertThat(recordsDelivered.get(0)).isEqualTo(record0);
+		assertThat(recordsDelivered.get(1)).isEqualTo(record1);
+		assertThat(recordsDelivered.get(2)).isEqualTo(record1);
 	}
 
 	public static class TestMessageListener1 implements MessageListener<String, String>, ConsumerSeekAware {