diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/adapter/MessagingMessageListenerAdapter.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/adapter/MessagingMessageListenerAdapter.java index f1cc77b3c6..7275a1b446 100644 --- a/spring-kafka/src/main/java/org/springframework/kafka/listener/adapter/MessagingMessageListenerAdapter.java +++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/adapter/MessagingMessageListenerAdapter.java @@ -80,6 +80,7 @@ * @author Gary Russell * @author Artem Bilan * @author Venil Noronha + * @author Nathan Xu */ public abstract class MessagingMessageListenerAdapter implements ConsumerSeekAware { @@ -470,8 +471,8 @@ protected void sendResponse(Object result, String topic, @Nullable Object source if (!returnTypeMessage && topic == null) { this.logger.debug(() -> "No replyTopic to handle the reply: " + result); } - else if (result instanceof Message) { - Message reply = checkHeaders(result, topic, source); + else if (result instanceof Message mResult) { + Message reply = checkHeaders(mResult, topic, source); this.replyTemplate.send(reply); } else { @@ -483,8 +484,9 @@ else if (result instanceof Message) { } if (iterableOfMessages || this.splitIterables) { ((Iterable) result).forEach(v -> { - if (v instanceof Message) { - this.replyTemplate.send((Message) v); + if (v instanceof Message mv) { + Message aReply = checkHeaders(mv, topic, source); + this.replyTemplate.send(aReply); } else { this.replyTemplate.send(topic, v); @@ -501,12 +503,12 @@ else if (result instanceof Message) { } } - private Message checkHeaders(Object result, String topic, @Nullable Object source) { // NOSONAR (complexity) - Message reply = (Message) result; + private Message checkHeaders(Message reply, @Nullable String topic, @Nullable Object source) { // NOSONAR (complexity) MessageHeaders headers = reply.getHeaders(); - boolean needsTopic = headers.get(KafkaHeaders.TOPIC) == null; + boolean needsTopic = topic != null && headers.get(KafkaHeaders.TOPIC) == null; boolean sourceIsMessage = source instanceof Message; - boolean needsCorrelation = headers.get(this.correlationHeaderName) == null && sourceIsMessage; + boolean needsCorrelation = headers.get(this.correlationHeaderName) == null && sourceIsMessage + && getCorrelation((Message) source) != null; boolean needsPartition = headers.get(KafkaHeaders.PARTITION) == null && sourceIsMessage && getReplyPartition((Message) source) != null; if (needsTopic || needsCorrelation || needsPartition) { @@ -514,11 +516,10 @@ private Message checkHeaders(Object result, String topic, @Nullable Object so if (needsTopic) { builder.setHeader(KafkaHeaders.TOPIC, topic); } - if (needsCorrelation && sourceIsMessage) { - builder.setHeader(this.correlationHeaderName, - ((Message) source).getHeaders().get(this.correlationHeaderName)); + if (needsCorrelation) { + setCorrelation(builder, (Message) source); } - if (sourceIsMessage && reply.getHeaders().get(KafkaHeaders.REPLY_PARTITION) == null) { + if (needsPartition) { setPartition(builder, (Message) source); } reply = builder.build(); @@ -531,8 +532,8 @@ private void sendSingleResult(Object result, String topic, @Nullable Object sour byte[] correlationId = null; boolean sourceIsMessage = source instanceof Message; if (sourceIsMessage - && ((Message) source).getHeaders().get(this.correlationHeaderName) != null) { - correlationId = ((Message) source).getHeaders().get(this.correlationHeaderName, byte[].class); + && getCorrelation((Message) source) != null) { + correlationId = getCorrelation((Message) source); } if (sourceIsMessage) { sendReplyForMessageSource(result, topic, source, correlationId); @@ -571,6 +572,18 @@ private void sendReplyForMessageSource(Object result, String topic, Object sourc this.replyTemplate.send(builder.build()); } + private void setCorrelation(MessageBuilder builder, Message source) { + byte[] correlationBytes = getCorrelation(source); + if (correlationBytes != null) { + builder.setHeader(this.correlationHeaderName, correlationBytes); + } + } + + @Nullable + private byte[] getCorrelation(Message source) { + return source.getHeaders().get(this.correlationHeaderName, byte[].class); + } + private void setPartition(MessageBuilder builder, Message source) { byte[] partitionBytes = getReplyPartition(source); if (partitionBytes != null) { diff --git a/spring-kafka/src/test/java/org/springframework/kafka/requestreply/ReplyingKafkaTemplateTests.java b/spring-kafka/src/test/java/org/springframework/kafka/requestreply/ReplyingKafkaTemplateTests.java index 319221d7c6..dd7e406701 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/requestreply/ReplyingKafkaTemplateTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/requestreply/ReplyingKafkaTemplateTests.java @@ -100,6 +100,7 @@ /** * @author Gary Russell + * @author Nathan Xu * @since 2.1.3 * */ @@ -116,7 +117,8 @@ ReplyingKafkaTemplateTests.I_REPLY, ReplyingKafkaTemplateTests.I_REQUEST, ReplyingKafkaTemplateTests.J_REPLY, ReplyingKafkaTemplateTests.J_REQUEST, ReplyingKafkaTemplateTests.K_REPLY, ReplyingKafkaTemplateTests.K_REQUEST, - ReplyingKafkaTemplateTests.L_REPLY, ReplyingKafkaTemplateTests.L_REQUEST }) + ReplyingKafkaTemplateTests.L_REPLY, ReplyingKafkaTemplateTests.L_REQUEST, + ReplyingKafkaTemplateTests.M_REPLY, ReplyingKafkaTemplateTests.M_REQUEST }) public class ReplyingKafkaTemplateTests { public static final String A_REPLY = "aReply"; @@ -167,6 +169,10 @@ public class ReplyingKafkaTemplateTests { public static final String L_REQUEST = "lRequest"; + public static final String M_REPLY = "mReply"; + + public static final String M_REQUEST = "mRequest"; + @Autowired private EmbeddedKafkaBroker embeddedKafka; @@ -845,6 +851,24 @@ void requestTimeoutWithMessage() throws Exception { } } + @Test + void testMessageIterableReturn() throws Exception { + ReplyingKafkaTemplate template = createTemplate(M_REPLY); + try { + template.setDefaultReplyTimeout(Duration.ofSeconds(30)); + Headers headers = new RecordHeaders(); + ProducerRecord record = new ProducerRecord<>(M_REQUEST, null, null, null, "foo", headers); + RequestReplyFuture future = template.sendAndReceive(record); + future.getSendFuture().get(10, TimeUnit.SECONDS); // send ok + ConsumerRecord consumerRecord = future.get(30, TimeUnit.SECONDS); + assertThat(consumerRecord.value()).isEqualTo("FOO"); + } + finally { + template.stop(); + template.destroy(); + } + } + @Configuration @EnableKafka public static class Config { @@ -1011,6 +1035,15 @@ public Message handleL(String in) throws InterruptedException { .build(); } + @KafkaListener(id = M_REQUEST, topics = M_REQUEST) + @SendTo // default REPLY_TOPIC header + public List> handleM(String in) throws InterruptedException { + Message message = MessageBuilder.withPayload(in.toUpperCase()) + .setHeader("serverSentAnError", "user error") + .build(); + return Collections.singletonList(message); + } + } @KafkaListener(topics = C_REQUEST, groupId = C_REQUEST)