Skip to content

Commit

Permalink
KAFKA-15661: KIP-951: Server side changes (apache#14444)
Browse files Browse the repository at this point in the history
This is the server side changes to populate the fields in KIP-951. On NOT_LEADER_OR_FOLLOWER errors in both FETCH and PRODUCE the new leader ID and epoch are retrieved from the local cache through ReplicaManager and included in the response, falling back to the metadata cache if they are unavailable there. The endpoint for the new leader is retrieved from the metadata cache. The new fields are all optional (tagged) and an IBP bump was required.

https://cwiki.apache.org/confluence/display/KAFKA/KIP-951%3A+Leader+discovery+optimisations+for+the+client

https://issues.apache.org/jira/browse/KAFKA-15661

Protocol changes: apache#14627

Testing
Benchmarking described here https://cwiki.apache.org/confluence/display/KAFKA/KIP-951%3A+Leader+discovery+optimisations+for+the+client#KIP951:Leaderdiscoveryoptimisationsfortheclient-BenchmarkResults
./gradlew core:test --tests kafka.server.KafkaApisTest

Reviewers: Justine Olshan <jolshan@confluent.io>, David Jacot <djacot@confluent.io>, Jason Gustafson <jason@confluent.io>, Fred Zheng <zhengyd2014@gmail.com>, Mayank Shekhar Narula <mayanks.narula@gmail.com>,  Yang Yang <yayang@uber.com>, David Mao <dmao@confluent.io>, Kirk True <ktrue@confluent.io>
  • Loading branch information
chb2ab authored and yyu1993 committed Feb 15, 2024
1 parent db63f4f commit ad0beeb
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.kafka.common.requests;

import org.apache.kafka.common.Node;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.ProduceResponseData;
import org.apache.kafka.common.message.ProduceResponseData.LeaderIdAndEpoch;
Expand Down Expand Up @@ -72,7 +73,7 @@ public ProduceResponse(ProduceResponseData produceResponseData) {
*/
@Deprecated
public ProduceResponse(Map<TopicPartition, PartitionResponse> responses) {
this(responses, DEFAULT_THROTTLE_TIME);
this(responses, DEFAULT_THROTTLE_TIME, Collections.emptyList());
}

/**
Expand All @@ -83,10 +84,23 @@ public ProduceResponse(Map<TopicPartition, PartitionResponse> responses) {
*/
@Deprecated
public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs) {
this(toData(responses, throttleTimeMs));
this(toData(responses, throttleTimeMs, Collections.emptyList()));
}

private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs) {
/**
* Constructor for the latest version
* This is deprecated in favor of using the ProduceResponseData constructor, KafkaApis should switch to that
* in KAFKA-10730
* @param responses Produced data grouped by topic-partition
* @param throttleTimeMs Time in milliseconds the response was throttled
* @param nodeEndpoints List of node endpoints
*/
@Deprecated
public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs, List<Node> nodeEndpoints) {
this(toData(responses, throttleTimeMs, nodeEndpoints));
}

private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs, List<Node> nodeEndpoints) {
ProduceResponseData data = new ProduceResponseData().setThrottleTimeMs(throttleTimeMs);
responses.forEach((tp, response) -> {
ProduceResponseData.TopicProduceResponse tpr = data.responses().find(tp.topic());
Expand All @@ -110,6 +124,12 @@ private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse>
.setBatchIndexErrorMessage(e.message))
.collect(Collectors.toList())));
});
nodeEndpoints.forEach(endpoint -> data.nodeEndpoints()
.add(new ProduceResponseData.NodeEndpoint()
.setNodeId(endpoint.id())
.setHost(endpoint.host())
.setPort(endpoint.port())
.setRack(endpoint.rack())));
return data;
}

Expand Down
53 changes: 51 additions & 2 deletions core/src/main/scala/kafka/server/KafkaApis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,23 @@ class KafkaApis(val requestChannel: RequestChannel,
}
}

case class LeaderNode(leaderId: Int, leaderEpoch: Int, node: Option[Node])

private def getCurrentLeader(tp: TopicPartition, ln: ListenerName): LeaderNode = {
val partitionInfoOrError = replicaManager.getPartitionOrError(tp)
val (leaderId, leaderEpoch) = partitionInfoOrError match {
case Right(x) =>
(x.leaderReplicaIdOpt.getOrElse(-1), x.getLeaderEpoch)
case Left(x) =>
debug(s"Unable to retrieve local leaderId and Epoch with error $x, falling back to metadata cache")
metadataCache.getPartitionInfo(tp.topic, tp.partition) match {
case Some(pinfo) => (pinfo.leader(), pinfo.leaderEpoch())
case None => (-1, -1)
}
}
LeaderNode(leaderId, leaderEpoch, metadataCache.getAliveBrokerNode(leaderId, ln))
}

/**
* Handle a produce request
*/
Expand Down Expand Up @@ -614,6 +631,7 @@ class KafkaApis(val requestChannel: RequestChannel,
val mergedResponseStatus = responseStatus ++ unauthorizedTopicResponses ++ nonExistingTopicResponses ++ invalidRequestResponses
var errorInResponse = false

val nodeEndpoints = new mutable.HashMap[Int, Node]
mergedResponseStatus.forKeyValue { (topicPartition, status) =>
if (status.error != Errors.NONE) {
errorInResponse = true
Expand All @@ -622,6 +640,20 @@ class KafkaApis(val requestChannel: RequestChannel,
request.header.clientId,
topicPartition,
status.error.exceptionName))

if (request.header.apiVersion >= 10) {
status.error match {
case Errors.NOT_LEADER_OR_FOLLOWER =>
val leaderNode = getCurrentLeader(topicPartition, request.context.listenerName)
leaderNode.node.foreach { node =>
nodeEndpoints.put(node.id(), node)
}
status.currentLeader
.setLeaderId(leaderNode.leaderId)
.setLeaderEpoch(leaderNode.leaderEpoch)
case _ =>
}
}
}
}

Expand Down Expand Up @@ -665,7 +697,7 @@ class KafkaApis(val requestChannel: RequestChannel,
requestHelper.sendNoOpResponseExemptThrottle(request)
}
} else {
requestChannel.sendResponse(request, new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs), None)
requestChannel.sendResponse(request, new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs, nodeEndpoints.values.toList.asJava), None)
}
}

Expand Down Expand Up @@ -843,6 +875,7 @@ class KafkaApis(val requestChannel: RequestChannel,
.setRecords(unconvertedRecords)
.setPreferredReadReplica(partitionData.preferredReadReplica)
.setDivergingEpoch(partitionData.divergingEpoch)
.setCurrentLeader(partitionData.currentLeader())
}
}
}
Expand All @@ -851,6 +884,7 @@ class KafkaApis(val requestChannel: RequestChannel,
def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
val partitions = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
val reassigningPartitions = mutable.Set[TopicIdPartition]()
val nodeEndpoints = new mutable.HashMap[Int, Node]
responsePartitionData.foreach { case (tp, data) =>
val abortedTransactions = data.abortedTransactions.orElse(null)
val lastStableOffset: Long = data.lastStableOffset.orElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
Expand All @@ -864,6 +898,21 @@ class KafkaApis(val requestChannel: RequestChannel,
.setAbortedTransactions(abortedTransactions)
.setRecords(data.records)
.setPreferredReadReplica(data.preferredReadReplica.orElse(FetchResponse.INVALID_PREFERRED_REPLICA_ID))

if (versionId >= 16) {
data.error match {
case Errors.NOT_LEADER_OR_FOLLOWER | Errors.FENCED_LEADER_EPOCH =>
val leaderNode = getCurrentLeader(tp.topicPartition(), request.context.listenerName)
leaderNode.node.foreach { node =>
nodeEndpoints.put(node.id(), node)
}
partitionData.currentLeader()
.setLeaderId(leaderNode.leaderId)
.setLeaderEpoch(leaderNode.leaderEpoch)
case _ =>
}
}

data.divergingEpoch.ifPresent(partitionData.setDivergingEpoch(_))
partitions.put(tp, partitionData)
}
Expand All @@ -887,7 +936,7 @@ class KafkaApis(val requestChannel: RequestChannel,

// Prepare fetch response from converted data
val response =
FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData)
FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData, nodeEndpoints.values.toList.asJava)
// record the bytes out metrics only when the response is being sent
response.data.responses.forEach { topicResponse =>
topicResponse.partitions.forEach { data =>
Expand Down
Loading

0 comments on commit ad0beeb

Please sign in to comment.