
Commit 928f253

jerrypeng authored and viirya committed
[SPARK-54027] Kafka Source RTM support
### What changes were proposed in this pull request?
Add support for Real-time Mode in the Kafka source: KafkaMicroBatchStream now implements the SupportsRealTimeMode interface, and KafkaBatchPartitionReader extends the SupportsRealTimeRead interface.

### Why are the changes needed?
So that the Kafka source and sink can be used by Real-time Mode queries.

### Does this PR introduce _any_ user-facing change?
Yes, the Kafka source and sink can be used by Real-time Mode queries.

### How was this patch tested?
Many tests added.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #52729 from jerrypeng/SPARK-54027-int.

Authored-by: Jerry Peng <jerry.peng@databricks.com>
Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
1 parent 7dd973d commit 928f253
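
For orientation before the per-file diffs: a self-contained Scala sketch of the read loop this change enables. The names below (ToyRecordStatus, ToyRealTimeReader, drainUntil, the queue-backed record source) are illustrative stand-ins only, not the Spark interfaces touched by this commit; the real counterparts are SupportsRealTimeMode, SupportsRealTimeRead, and SupportsRealTimeRead.RecordStatus shown in the diffs below.

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

// Toy analogue of RecordStatus: did a record arrive within the timeout,
// and (optionally) its broker-side timestamp.
final case class ToyRecordStatus(hasRecord: Boolean, arrivalTimeMs: Option[Long])

final class ToyRealTimeReader(records: LinkedBlockingQueue[(String, Long)]) {
  private var current: String = _
  private var offset: Long = 0L

  // Analogue of nextWithTimeout: block for at most timeoutMs waiting for the next record.
  def nextWithTimeout(timeoutMs: Long): ToyRecordStatus =
    Option(records.poll(timeoutMs, TimeUnit.MILLISECONDS)) match {
      case Some((value, arrivalMs)) =>
        current = value
        offset += 1
        ToyRecordStatus(hasRecord = true, Some(arrivalMs))
      case None =>
        ToyRecordStatus(hasRecord = false, None)
    }

  def get(): String = current

  // Analogue of getOffset(): the offset right after the last record returned.
  def getOffset(): Long = offset
}

object ToyRealTimeLoop {
  // Engine-side shape: keep polling one partition until the batch deadline passes,
  // then report how far this partition got so the stream can merge offsets.
  def drainUntil(reader: ToyRealTimeReader, deadlineMs: Long): Long = {
    var remaining = deadlineMs - System.currentTimeMillis()
    while (remaining > 0) {
      if (reader.nextWithTimeout(remaining).hasRecord) {
        println(s"read: ${reader.get()}")
      }
      remaining = deadlineMs - System.currentTimeMillis()
    }
    reader.getOffset()
  }
}

The design point mirrored here is that a real-time batch has no predetermined end offset: readers pull until a deadline, and the offsets they reach become the batch's end offset.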

File tree

9 files changed (+1396 −15 lines)


common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java

Lines changed: 1 addition & 0 deletions
@@ -824,6 +824,7 @@ public enum LogKeys implements LogKey {
   TIMEOUT,
   TIMER,
   TIMESTAMP,
+  TIMESTAMP_COLUMN_NAME,
   TIME_UNITS,
   TIP,
   TOKEN,

connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala

Lines changed: 46 additions & 3 deletions
@@ -19,15 +19,19 @@ package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
 
+import org.apache.kafka.common.record.TimestampType
+
 import org.apache.spark.TaskContext
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, LogKeys}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.connector.metric.CustomTaskMetric
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead
+import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus
 import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution, StreamExecution}
-import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer
+import org.apache.spark.sql.kafka010.consumer.{KafkaDataConsumer, KafkaDataConsumerIterator}
 
 /** A [[InputPartition]] for reading Kafka data in a batch based streaming query. */
 private[kafka010] case class KafkaBatchInputPartition(
@@ -67,7 +71,8 @@ private case class KafkaBatchPartitionReader(
     executorKafkaParams: ju.Map[String, Object],
     pollTimeoutMs: Long,
     failOnDataLoss: Boolean,
-    includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging {
+    includeHeaders: Boolean)
+  extends SupportsRealTimeRead[InternalRow] with Logging {
 
   private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams)
 
@@ -77,6 +82,12 @@ private case class KafkaBatchPartitionReader(
 
   private var nextOffset = rangeToRead.fromOffset
   private var nextRow: UnsafeRow = _
+  private var iteratorForRealTimeMode: Option[KafkaDataConsumerIterator] = None
+
+  // Boolean flag that indicates whether we have logged the type of timestamp (i.e. create time,
+  // log-append time, etc.) for the Kafka source. We log upon reading the first record, and we
+  // then skip logging for subsequent records.
+  private var timestampTypeLogged = false
 
   override def next(): Boolean = {
     if (nextOffset < rangeToRead.untilOffset) {
@@ -93,6 +104,38 @@ private case class KafkaBatchPartitionReader(
     }
   }
 
+  override def nextWithTimeout(timeoutMs: java.lang.Long): RecordStatus = {
+    if (!iteratorForRealTimeMode.isDefined) {
+      logInfo(s"Getting a new kafka consuming iterator for ${offsetRange.topicPartition} " +
+        s"starting from ${nextOffset}, timeoutMs ${timeoutMs}")
+      iteratorForRealTimeMode = Some(consumer.getIterator(nextOffset))
+    }
+    assert(iteratorForRealTimeMode.isDefined)
+    val nextRecord = iteratorForRealTimeMode.get.nextWithTimeout(timeoutMs)
+    nextRecord.foreach { record =>
+
+      nextRow = unsafeRowProjector(record)
+      nextOffset = record.offset + 1
+      if (record.timestampType() == TimestampType.LOG_APPEND_TIME ||
+          record.timestampType() == TimestampType.CREATE_TIME) {
+        if (!timestampTypeLogged) {
+          logInfo(log"Kafka source record timestamp type is " +
+            log"${MDC(LogKeys.TIMESTAMP_COLUMN_NAME, record.timestampType())}")
+          timestampTypeLogged = true
+        }
+
+        RecordStatus.newStatusWithArrivalTimeMs(record.timestamp())
+      } else {
+        RecordStatus.newStatusWithoutArrivalTime(true)
+      }
+    }
+    RecordStatus.newStatusWithoutArrivalTime(nextRecord.isDefined)
+  }
+
+  override def getOffset(): KafkaSourcePartitionOffset = {
+    KafkaSourcePartitionOffset(offsetRange.topicPartition, nextOffset)
+  }
+
   override def get(): UnsafeRow = {
     assert(nextRow != null)
     nextRow
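
To connect this reader to the stream-side change in the next file: each reader's getOffset() reports the offset just past the last record it returned (nextOffset = record.offset + 1 above), and the driver merges those per-partition offsets into the batch's end offset via mergeOffsets(). A minimal sketch with toy types; ToyPartitionOffset, ToySourceOffset, and the topic names are illustrative stand-ins, the real classes are KafkaSourcePartitionOffset and KafkaSourceOffset.

object MergeOffsetsSketch extends App {
  final case class ToyPartitionOffset(topicPartition: String, offset: Long)
  final case class ToySourceOffset(partitionToOffsets: Map[String, Long])

  // Mirrors the shape of KafkaMicroBatchStream.mergeOffsets in the next file:
  // collect the per-partition offsets reported by the readers into one source offset.
  def mergeOffsets(offsets: Array[ToyPartitionOffset]): ToySourceOffset =
    ToySourceOffset(offsets.map(po => po.topicPartition -> po.offset).toMap)

  // Each reader reports the offset after the last record it returned.
  val reported = Array(ToyPartitionOffset("topic-0", 105L), ToyPartitionOffset("topic-1", 42L))
  println(mergeOffsets(reported).partitionToOffsets)  // Map(topic-0 -> 105, topic-1 -> 42)
}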

connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala

Lines changed: 133 additions & 7 deletions
@@ -26,7 +26,7 @@ import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.LogKeys.{ERROR, OFFSETS, TIP}
+import org.apache.spark.internal.LogKeys.{ERROR, OFFSETS, TIP, TOPIC_PARTITION_OFFSET}
 import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
@@ -60,7 +60,11 @@ private[kafka010] class KafkaMicroBatchStream(
     metadataPath: String,
     startingOffsets: KafkaOffsetRangeLimit,
     failOnDataLoss: Boolean)
-  extends SupportsTriggerAvailableNow with ReportsSourceMetrics with MicroBatchStream with Logging {
+  extends SupportsTriggerAvailableNow
+  with SupportsRealTimeMode
+  with ReportsSourceMetrics
+  with MicroBatchStream
+  with Logging {
 
   private[kafka010] val pollTimeoutMs = options.getLong(
     KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
@@ -93,6 +97,11 @@ private[kafka010] class KafkaMicroBatchStream(
 
   private var isTriggerAvailableNow: Boolean = false
 
+  private var inRealTimeMode = false
+  override def prepareForRealTimeMode(): Unit = {
+    inRealTimeMode = true
+  }
+
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
    * called in StreamExecutionThread. Otherwise, interrupting a thread while running
@@ -218,6 +227,107 @@ private[kafka010] class KafkaMicroBatchStream(
     }.toArray
   }
 
+  override def planInputPartitions(start: Offset): Array[InputPartition] = {
+    // This function is used for real time mode. Trigger restrictions won't be supported.
+    if (maxOffsetsPerTrigger.isDefined) {
+      throw new UnsupportedOperationException(
+        "maxOffsetsPerTrigger is not compatible with real time mode")
+    }
+    if (minOffsetPerTrigger.isDefined) {
+      throw new UnsupportedOperationException(
+        "minOffsetsPerTrigger is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) {
+      throw new UnsupportedOperationException(
+        "minpartitions is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.ENDING_TIMESTAMP_OPTION_KEY)) {
+      throw new UnsupportedOperationException(
+        "endingtimestamp is not compatible with real time mode"
+      )
+    }
+    if (options.containsKey(KafkaSourceProvider.MAX_TRIGGER_DELAY)) {
+      throw new UnsupportedOperationException(
+        "maxtriggerdelay is not compatible with real time mode"
+      )
+    }
+
+    // This function is used by Real-time Mode, where we expect 1:1 mapping between a
+    // topic partition and an input partition.
+    // We are skipping partition range check for performance reason. We can always try to do
+    // it in tasks if needed.
+    val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets
+
+    // Here we check previous topic partitions with latest partition offsets to see if we need to
+    // update the partition list. Here we don't need the updated partition topic to be absolutely
+    // up to date, because there might already be minutes' delay since new partition is created.
+    // latestPartitionOffsets should be fetched not long ago anyway.
+    // If the topic partitions change, we fetch the earliest offsets for all new partitions
+    // and add them to the list.
+    assert(latestPartitionOffsets != null, "latestPartitionOffsets should be set in latestOffset")
+    val latestTopicPartitions = latestPartitionOffsets.keySet
+    val newStartPartitionOffsets = if (startPartitionOffsets.keySet == latestTopicPartitions) {
+      startPartitionOffsets
+    } else {
+      val newPartitions = latestTopicPartitions.diff(startPartitionOffsets.keySet)
+      // Instead of fetching earliest offsets, we could fill offset 0 here and avoid this extra
+      // admin function call. But we consider new partition is rare and getting earliest offset
+      // aligns with what we do in micro-batch mode and can potentially enable more sanity checks
+      // in executor side.
+      val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+
+      assert(
+        newPartitionOffsets.keys.forall(!startPartitionOffsets.contains(_)),
+        "startPartitionOffsets should not contain any key in newPartitionOffsets")
+
+      logInfo(log"Partitions added: ${MDC(TOPIC_PARTITION_OFFSET, newPartitionOffsets)}")
+      // Filter out new partition offsets that are not 0 and log a warning
+      val nonZeroNewPartitionOffsets = newPartitionOffsets.filter {
+        case (_, offset) => offset != 0
+      }
+      // Log the non-zero new partition offsets
+      if (nonZeroNewPartitionOffsets.nonEmpty) {
+        logWarning(log"new partitions should start from offset 0: " +
+          log"${MDC(OFFSETS, nonZeroNewPartitionOffsets)}")
+        nonZeroNewPartitionOffsets.foreach {
+          case (p, o) =>
+            reportDataLoss(
+              s"Added partition $p starts from $o instead of 0. Some data may have been missed",
+              () => KafkaExceptions.addedPartitionDoesNotStartFromZero(p, o))
+        }
+      }
+
+      val deletedPartitions = startPartitionOffsets.keySet.diff(latestTopicPartitions)
+      if (deletedPartitions.nonEmpty) {
+        reportDataLoss(
+          s"$deletedPartitions are gone. Some data may have been missed",
+          () =>
+            KafkaExceptions.partitionsDeleted(deletedPartitions, None))
+      }
+
+      startPartitionOffsets ++ newPartitionOffsets
+    }
+
+    newStartPartitionOffsets.keySet.toSeq.map { tp =>
+      val fromOffset = newStartPartitionOffsets(tp)
+      KafkaBatchInputPartition(
+        KafkaOffsetRange(tp, fromOffset, Long.MaxValue, preferredLoc = None),
+        executorKafkaParams,
+        pollTimeoutMs,
+        failOnDataLoss,
+        includeHeaders)
+    }.toArray
+  }
+
+  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+    val mergedMap = offsets.map {
+      case KafkaSourcePartitionOffset(p, o) => (p, o)
+    }.toMap
+    KafkaSourceOffset(mergedMap)
+  }
+
   override def createReaderFactory(): PartitionReaderFactory = {
     KafkaBatchReaderFactory
   }
@@ -235,7 +345,22 @@ private[kafka010] class KafkaMicroBatchStream(
   override def toString(): String = s"KafkaV2[$kafkaOffsetReader]"
 
   override def metrics(latestConsumedOffset: Optional[Offset]): ju.Map[String, String] = {
-    KafkaMicroBatchStream.metrics(latestConsumedOffset, latestPartitionOffsets)
+    val reCalculatedLatestPartitionOffsets =
+      if (inRealTimeMode) {
+        if (!latestConsumedOffset.isPresent) {
+          // this means a batch has no end offsets, which should not happen
+          None
+        } else {
+          Some {
+            kafkaOffsetReader.fetchLatestOffsets(
+              Some(latestConsumedOffset.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets))
+          }
+        }
+      } else {
+        Some(latestPartitionOffsets)
+      }
+
+    KafkaMicroBatchStream.metrics(latestConsumedOffset, reCalculatedLatestPartitionOffsets)
   }
 
   /**
@@ -386,13 +511,14 @@ object KafkaMicroBatchStream extends Logging {
   */
  def metrics(
      latestConsumedOffset: Optional[Offset],
-      latestAvailablePartitionOffsets: PartitionOffsetMap): ju.Map[String, String] = {
+      latestAvailablePartitionOffsets: Option[PartitionOffsetMap]): ju.Map[String, String] = {
    val offset = Option(latestConsumedOffset.orElse(null))
 
-    if (offset.nonEmpty && latestAvailablePartitionOffsets != null) {
+    if (offset.nonEmpty && latestAvailablePartitionOffsets.isDefined) {
      val consumedPartitionOffsets = offset.map(KafkaSourceOffset(_)).get.partitionToOffsets
-      val offsetsBehindLatest = latestAvailablePartitionOffsets
-        .map(partitionOffset => partitionOffset._2 - consumedPartitionOffsets(partitionOffset._1))
+      val offsetsBehindLatest = latestAvailablePartitionOffsets.get
+        .map(partitionOffset => partitionOffset._2 -
+          consumedPartitionOffsets.getOrElse(partitionOffset._1, 0L))
      if (offsetsBehindLatest.nonEmpty) {
        val avgOffsetBehindLatest = offsetsBehindLatest.sum.toDouble / offsetsBehindLatest.size
        return Map[String, String](
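
A worked example of the lag metric above, reflecting the two changes in this hunk: the Option[PartitionOffsetMap] parameter and the getOrElse(_, 0L) fallback for partitions with no consumed offset yet. The string keys and offset values below are made up for illustration; the real map is keyed by TopicPartition.

object LagMetricSketch extends App {
  // Illustrative offsets; "t-2" is a newly added partition with no consumed offset yet.
  val latestAvailable: Option[Map[String, Long]] =
    Some(Map("t-0" -> 120L, "t-1" -> 80L, "t-2" -> 10L))
  val consumed: Map[String, Long] = Map("t-0" -> 100L, "t-1" -> 80L)

  // Same shape as the patched metrics(): treat a missing consumed offset as 0.
  val offsetsBehindLatest = latestAvailable.get.map {
    case (tp, latest) => latest - consumed.getOrElse(tp, 0L)
  }
  val avgOffsetBehindLatest = offsetsBehindLatest.sum.toDouble / offsetsBehindLatest.size

  println(offsetsBehindLatest.toList)  // List(20, 0, 10)
  println(avgOffsetBehindLatest)       // 10.0
}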
