diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java index baa9b4f35db4c..9e885b091d7bf 100644 --- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java +++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java @@ -823,6 +823,7 @@ public enum LogKeys implements LogKey { TIMEOUT, TIMER, TIMESTAMP, + TIMESTAMP_COLUMN_NAME, TIME_UNITS, TIP, TOKEN, diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala index 02568aa89eb1d..9fcdf1a7d9bf4 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchPartitionReader.scala @@ -19,15 +19,19 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import org.apache.kafka.common.record.TimestampType + import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.connector.metric.CustomTaskMetric import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} +import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead +import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus import org.apache.spark.sql.execution.streaming.runtime.{MicroBatchExecution, StreamExecution} -import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer +import org.apache.spark.sql.kafka010.consumer.{KafkaDataConsumer, KafkaDataConsumerIterator} /** A [[InputPartition]] for reading Kafka data in a batch based streaming query. */ private[kafka010] case class KafkaBatchInputPartition( @@ -67,7 +71,8 @@ private case class KafkaBatchPartitionReader( executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, failOnDataLoss: Boolean, - includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging { + includeHeaders: Boolean) + extends SupportsRealTimeRead[InternalRow] with Logging { private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams) @@ -77,6 +82,12 @@ private case class KafkaBatchPartitionReader( private var nextOffset = rangeToRead.fromOffset private var nextRow: UnsafeRow = _ + private var iteratorForRealTimeMode: Option[KafkaDataConsumerIterator] = None + + // Boolean flag that indicates whether we have logged the type of timestamp (i.e. create time, + // log-append time, etc.) for the Kafka source. We log upon reading the first record, and we + // then skip logging for subsequent records. 
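The reader now implements the real-time read contract instead of plain `PartitionReader`. A simplified sketch of the shape this reader has to satisfy, reconstructed only from the overrides visible later in this patch (it is not the actual interface definition in the connector API and may differ in detail):

```scala
// Illustrative shape only; the real interface is
// org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.
import org.apache.spark.sql.connector.read.PartitionReader
import org.apache.spark.sql.connector.read.streaming.PartitionOffset
import org.apache.spark.sql.connector.read.streaming.SupportsRealTimeRead.RecordStatus

trait RealTimeReaderShape[T] extends PartitionReader[T] {
  // Block for up to timeoutMs waiting for one more record; the returned status says
  // whether a record was read and, when known, its arrival timestamp.
  def nextWithTimeout(timeoutMs: java.lang.Long): RecordStatus
  // The offset this partition has consumed up to, reported when the batch is cut.
  def getOffset(): PartitionOffset
}
```

The `timestampTypeLogged` flag declared next implements the one-time logging described in the comment above.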
+ private var timestampTypeLogged = false override def next(): Boolean = { if (nextOffset < rangeToRead.untilOffset) { @@ -93,6 +104,38 @@ private case class KafkaBatchPartitionReader( } } + override def nextWithTimeout(timeoutMs: java.lang.Long): RecordStatus = { + if (!iteratorForRealTimeMode.isDefined) { + logInfo(s"Getting a new kafka consuming iterator for ${offsetRange.topicPartition} " + + s"starting from ${nextOffset}, timeoutMs ${timeoutMs}") + iteratorForRealTimeMode = Some(consumer.getIterator(nextOffset)) + } + assert(iteratorForRealTimeMode.isDefined) + val nextRecord = iteratorForRealTimeMode.get.nextWithTimeout(timeoutMs) + nextRecord.foreach { record => + + nextRow = unsafeRowProjector(record) + nextOffset = record.offset + 1 + if (record.timestampType() == TimestampType.LOG_APPEND_TIME || + record.timestampType() == TimestampType.CREATE_TIME) { + if (!timestampTypeLogged) { + logInfo(log"Kafka source record timestamp type is " + + log"${MDC(LogKeys.TIMESTAMP_COLUMN_NAME, record.timestampType())}") + timestampTypeLogged = true + } + + RecordStatus.newStatusWithArrivalTimeMs(record.timestamp()) + } else { + RecordStatus.newStatusWithoutArrivalTime(true) + } + } + RecordStatus.newStatusWithoutArrivalTime(nextRecord.isDefined) + } + + override def getOffset(): KafkaSourcePartitionOffset = { + KafkaSourcePartitionOffset(offsetRange.topicPartition, nextOffset) + } + override def get(): UnsafeRow = { assert(nextRow != null) nextRow diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 7449e91230334..828891f0b4983 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -26,7 +26,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.internal.LogKeys.{ERROR, OFFSETS, TIP} +import org.apache.spark.internal.LogKeys.{ERROR, OFFSETS, TIP, TOPIC_PARTITION_OFFSET} import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} @@ -60,7 +60,11 @@ private[kafka010] class KafkaMicroBatchStream( metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) - extends SupportsTriggerAvailableNow with ReportsSourceMetrics with MicroBatchStream with Logging { + extends SupportsTriggerAvailableNow + with SupportsRealTimeMode + with ReportsSourceMetrics + with MicroBatchStream + with Logging { private[kafka010] val pollTimeoutMs = options.getLong( KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, @@ -93,6 +97,11 @@ private[kafka010] class KafkaMicroBatchStream( private var isTriggerAvailableNow: Boolean = false + private var inRealTimeMode = false + override def prepareForRealTimeMode(): Unit = { + inRealTimeMode = true + } + /** * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only * called in StreamExecutionThread. Otherwise, interrupting a thread while running @@ -218,6 +227,107 @@ private[kafka010] class KafkaMicroBatchStream( }.toArray } + override def planInputPartitions(start: Offset): Array[InputPartition] = { + // This function is used for real time mode. Trigger restrictions won't be supported. 
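The checks that follow reject options whose semantics depend on micro-batch admission control. They are written out inline; for the option-key cases the common pattern is simply a guard like the hypothetical helper below (the helper name and structure are assumptions, the keys are the ones checked in this method):

```scala
// Hypothetical helper equivalent to the inline option checks below.
private def rejectInRealTimeMode(key: String): Unit = {
  if (options.containsKey(key)) {
    throw new UnsupportedOperationException(
      s"$key is not compatible with real time mode")
  }
}
// e.g. rejectInRealTimeMode(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)
```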
+ if (maxOffsetsPerTrigger.isDefined) { + throw new UnsupportedOperationException( + "maxOffsetsPerTrigger is not compatible with real time mode") + } + if (minOffsetPerTrigger.isDefined) { + throw new UnsupportedOperationException( + "minOffsetsPerTrigger is not compatible with real time mode" + ) + } + if (options.containsKey(KafkaSourceProvider.MIN_PARTITIONS_OPTION_KEY)) { + throw new UnsupportedOperationException( + "minpartitions is not compatible with real time mode" + ) + } + if (options.containsKey(KafkaSourceProvider.ENDING_TIMESTAMP_OPTION_KEY)) { + throw new UnsupportedOperationException( + "endingtimestamp is not compatible with real time mode" + ) + } + if (options.containsKey(KafkaSourceProvider.MAX_TRIGGER_DELAY)) { + throw new UnsupportedOperationException( + "maxtriggerdelay is not compatible with real time mode" + ) + } + + // This function is used by Real-time Mode, where we expect 1:1 mapping between a + // topic partition and an input partition. + // We are skipping partition range check for performance reason. We can always try to do + // it in tasks if needed. + val startPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets + + // Here we check previous topic partitions with latest partition offsets to see if we need to + // update the partition list. Here we don't need the updated partition topic to be absolutely + // up to date, because there might already be minutes' delay since new partition is created. + // latestPartitionOffsets should be fetched not long ago anyway. + // If the topic partitions change, we fetch the earliest offsets for all new partitions + // and add them to the list. + assert(latestPartitionOffsets != null, "latestPartitionOffsets should be set in latestOffset") + val latestTopicPartitions = latestPartitionOffsets.keySet + val newStartPartitionOffsets = if (startPartitionOffsets.keySet == latestTopicPartitions) { + startPartitionOffsets + } else { + val newPartitions = latestTopicPartitions.diff(startPartitionOffsets.keySet) + // Instead of fetching earliest offsets, we could fill offset 0 here and avoid this extra + // admin function call. But we consider new partition is rare and getting earliest offset + // aligns with what we do in micro-batch mode and can potentially enable more sanity checks + // in executor side. + val newPartitionOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq) + + assert( + newPartitionOffsets.keys.forall(!startPartitionOffsets.contains(_)), + "startPartitionOffsets should not contain any key in newPartitionOffsets") + + logInfo(log"Partitions added: ${MDC(TOPIC_PARTITION_OFFSET, newPartitionOffsets)}") + // Filter out new partition offsets that are not 0 and log a warning + val nonZeroNewPartitionOffsets = newPartitionOffsets.filter { + case (_, offset) => offset != 0 + } + // Log the non-zero new partition offsets + if (nonZeroNewPartitionOffsets.nonEmpty) { + logWarning(log"new partitions should start from offset 0: " + + log"${MDC(OFFSETS, nonZeroNewPartitionOffsets)}") + nonZeroNewPartitionOffsets.foreach { + case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed", + () => KafkaExceptions.addedPartitionDoesNotStartFromZero(p, o)) + } + } + + val deletedPartitions = startPartitionOffsets.keySet.diff(latestTopicPartitions) + if (deletedPartitions.nonEmpty) { + reportDataLoss( + s"$deletedPartitions are gone. 
Some data may have been missed", + () => + KafkaExceptions.partitionsDeleted(deletedPartitions, None)) + } + + startPartitionOffsets ++ newPartitionOffsets + } + + newStartPartitionOffsets.keySet.toSeq.map { tp => + val fromOffset = newStartPartitionOffsets(tp) + KafkaBatchInputPartition( + KafkaOffsetRange(tp, fromOffset, Long.MaxValue, preferredLoc = None), + executorKafkaParams, + pollTimeoutMs, + failOnDataLoss, + includeHeaders) + }.toArray + } + + override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = { + val mergedMap = offsets.map { + case KafkaSourcePartitionOffset(p, o) => (p, o) + }.toMap + KafkaSourceOffset(mergedMap) + } + override def createReaderFactory(): PartitionReaderFactory = { KafkaBatchReaderFactory } @@ -235,7 +345,22 @@ private[kafka010] class KafkaMicroBatchStream( override def toString(): String = s"KafkaV2[$kafkaOffsetReader]" override def metrics(latestConsumedOffset: Optional[Offset]): ju.Map[String, String] = { - KafkaMicroBatchStream.metrics(latestConsumedOffset, latestPartitionOffsets) + val reCalculatedLatestPartitionOffsets = + if (inRealTimeMode) { + if (!latestConsumedOffset.isPresent) { + // this means a batch has no end offsets, which should not happen + None + } else { + Some { + kafkaOffsetReader.fetchLatestOffsets( + Some(latestConsumedOffset.get.asInstanceOf[KafkaSourceOffset].partitionToOffsets)) + } + } + } else { + Some(latestPartitionOffsets) + } + + KafkaMicroBatchStream.metrics(latestConsumedOffset, reCalculatedLatestPartitionOffsets) } /** @@ -386,13 +511,14 @@ object KafkaMicroBatchStream extends Logging { */ def metrics( latestConsumedOffset: Optional[Offset], - latestAvailablePartitionOffsets: PartitionOffsetMap): ju.Map[String, String] = { + latestAvailablePartitionOffsets: Option[PartitionOffsetMap]): ju.Map[String, String] = { val offset = Option(latestConsumedOffset.orElse(null)) - if (offset.nonEmpty && latestAvailablePartitionOffsets != null) { + if (offset.nonEmpty && latestAvailablePartitionOffsets.isDefined) { val consumedPartitionOffsets = offset.map(KafkaSourceOffset(_)).get.partitionToOffsets - val offsetsBehindLatest = latestAvailablePartitionOffsets - .map(partitionOffset => partitionOffset._2 - consumedPartitionOffsets(partitionOffset._1)) + val offsetsBehindLatest = latestAvailablePartitionOffsets.get + .map(partitionOffset => partitionOffset._2 - + consumedPartitionOffsets.getOrElse(partitionOffset._1, 0L)) if (offsetsBehindLatest.nonEmpty) { val avgOffsetBehindLatest = offsetsBehindLatest.sum.toDouble / offsetsBehindLatest.size return Map[String, String]( diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala index 2d1125294df27..ce2294c1836f7 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala @@ -63,6 +63,13 @@ private[kafka010] class InternalKafkaConsumer( private[consumer] var kafkaParamsWithSecurity: ju.Map[String, Object] = _ private val consumer = createConsumer() + def poll(pollTimeoutMs: Long): ju.List[ConsumerRecord[Array[Byte], Array[Byte]]] = { + val p = consumer.poll(Duration.ofMillis(pollTimeoutMs)) + val r = p.records(topicPartition) + logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") + r + } + /** * Poll messages from Kafka starting from `offset` and 
returns a pair of "list of consumer record" * and "offset after poll". The list of consumer record may be empty if the Kafka consumer fetches @@ -131,7 +138,7 @@ private[kafka010] class InternalKafkaConsumer( c } - private def seek(offset: Long): Unit = { + def seek(offset: Long): Unit = { logDebug(s"Seeking to $groupId $topicPartition $offset") consumer.seek(topicPartition, offset) } @@ -228,6 +235,19 @@ private[consumer] case class FetchedRecord( } } +/** + * This class keeps returning the next records. If no new record is available, it will keep + * polling until timeout. It is used by KafkaBatchPartitionReader.nextWithTimeout(), to reduce + * seeking overhead in real time mode. + */ +private[sql] trait KafkaDataConsumerIterator { + /** + * Return the next record + * @return None if no new record is available after `timeoutMs`. + */ + def nextWithTimeout(timeoutMs: Long): Option[ConsumerRecord[Array[Byte], Array[Byte]]] +} + /** * This class helps caller to read from Kafka leveraging consumer pool as well as fetched data pool. * This class throws error when data loss is detected while reading from Kafka. @@ -272,6 +292,82 @@ private[kafka010] class KafkaDataConsumer( // Starting timestamp when the consumer is created. private var startTimestampNano: Long = System.nanoTime() + /** + * Get an iterator that can return the next entry. It is used exclusively for real-time + * mode. + * + * It is called by KafkaBatchPartitionReader.nextWithTimeout(). Unlike get(), there is no + * out-of-bound check in this function. Since there is no endOffset given, we assume anything + * record is valid to return as long as it is at or after `offset`. + * + * @param startOffset, the starting positions to read from, inclusive. + */ + def getIterator(startOffset: Long): KafkaDataConsumerIterator = { + new KafkaDataConsumerIterator { + private var fetchedRecordList + : Option[ju.ListIterator[ConsumerRecord[Array[Byte], Array[Byte]]]] = None + private val consumer = getOrRetrieveConsumer() + private var firstRecord = true + private var _currentOffset: Long = startOffset - 1 + + private def fetchedRecordListHasNext(): Boolean = { + fetchedRecordList.map(_.hasNext).getOrElse(false) + } + + override def nextWithTimeout( + timeoutMs: Long): Option[ConsumerRecord[Array[Byte], Array[Byte]]] = { + var timeLeftMs = timeoutMs + + def timeAndDeductFromTimeLeftMs[T](body: => T): Unit = { + // To reduce timing the same operator twice, we reuse the timing results for + // totalTimeReadNanos and for timeoutMs. 
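The helper defined in the next few lines measures `body` once via `timeNanos` and charges the elapsed time both to `totalTimeReadNanos` and to the caller's remaining millisecond budget. A standalone sketch of the same pattern (only the `timeNanos`/`totalTimeReadNanos` names are taken from this file; everything else is illustrative):

```scala
// Sketch: time a block exactly once, then charge the elapsed time to a running total
// and to the per-call budget that a nextWithTimeout()-style method may spend.
object TimeAndDeductSketch {
  private var totalTimeReadNanos = 0L

  private def timeNanos[T](body: => T): T = {
    val start = System.nanoTime()
    try body finally { totalTimeReadNanos += System.nanoTime() - start }
  }

  def main(args: Array[String]): Unit = {
    var timeLeftMs = 100L
    def timeAndDeduct[T](body: => T): Unit = {
      val before = totalTimeReadNanos
      timeNanos(body)
      timeLeftMs -= (totalTimeReadNanos - before) / 1000000
    }
    timeAndDeduct { Thread.sleep(30) }   // roughly 30 ms charged against the budget
    println(s"time left after one poll-like call: ${timeLeftMs} ms")
  }
}
```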
+ val prevTime = totalTimeReadNanos + timeNanos { + body + } + timeLeftMs -= (totalTimeReadNanos - prevTime) / 1000000 + } + + if (firstRecord) { + timeAndDeductFromTimeLeftMs { + consumer.seek(startOffset) + firstRecord = false + } + } + while (!fetchedRecordListHasNext() && timeLeftMs > 0) { + timeAndDeductFromTimeLeftMs { + try { + val records = consumer.poll(timeLeftMs) + numPolls += 1 + if (!records.isEmpty) { + numRecordsPolled += records.size + fetchedRecordList = Some(records.listIterator) + } + } catch { + case ex: OffsetOutOfRangeException => + if (_currentOffset != -1) { + throw ex + } else { + Thread.sleep(10) // retry until the source partition is populated + assert(startOffset == 0) + consumer.seek(startOffset) + } + } + } + } + if (fetchedRecordListHasNext()) { + totalRecordsRead += 1 + val nextRecord = fetchedRecordList.get.next() + assert(nextRecord.offset > _currentOffset, "Kafka offset should be incremental.") + _currentOffset = nextRecord.offset + Some(nextRecord) + } else { + None + } + } + } + } + /** * Get the record for the given offset if available. * diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 7eabf05235060..f907d2f8e62b0 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1839,20 +1839,20 @@ abstract class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBa val latestOffset = Map[TopicPartition, Long]((topicPartition1, 3L), (topicPartition2, 6L)) // test empty offset. - assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(null), latestOffset).isEmpty) + assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(null), Some(latestOffset)).isEmpty) // test valid offsetsBehindLatest val offset = KafkaSourceOffset( Map[TopicPartition, Long]((topicPartition1, 1L), (topicPartition2, 2L))) assert( - KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), latestOffset) === + KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), Some(latestOffset)) === Map[String, String]( "minOffsetsBehindLatest" -> "2", "maxOffsetsBehindLatest" -> "4", "avgOffsetsBehindLatest" -> "3.0").asJava) // test null latestAvailablePartitionOffsets - assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), null).isEmpty) + assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), None).isEmpty) } } diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala new file mode 100644 index 0000000000000..a359dc355478a --- /dev/null +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeIntegrationSuite.scala @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.nio.file.Files +import java.util.Properties + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +import org.apache.kafka.clients.producer.{KafkaProducer, Producer, ProducerRecord} +import org.scalatest.BeforeAndAfterEach +import org.scalatest.matchers.should.Matchers +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkContext, ThreadAudit} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.execution.streaming.RealTimeTrigger +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{OutputMode, ResultsCollector, StreamingQuery, StreamRealTimeModeE2ESuiteBase, StreamRealTimeModeSuiteBase} +import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +class KafkaRealTimeModeE2ESuite extends KafkaSourceTest with StreamRealTimeModeE2ESuiteBase { + + override protected val defaultTrigger: RealTimeTrigger = RealTimeTrigger.apply("5 seconds") + + override protected def createSparkSession = + new TestSparkSession( + new SparkContext( + "local[15]", + "streaming-key-cuj" + ) + ) + + override def beforeEach(): Unit = { + super[KafkaSourceTest].beforeEach() + super[StreamRealTimeModeE2ESuiteBase].beforeEach() + } + + def getKafkaConsumerProperties: Properties = { + val props: Properties = new Properties() + props.put("bootstrap.servers", testUtils.brokerAddress) + props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer") + props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer") + props.put("compression.type", "snappy") + + props + } + + test("Union two kafka streams, for each write to sink") { + var q: StreamingQuery = null + try { + val topic1 = newTopic() + val topic2 = newTopic() + testUtils.createTopic(topic1, partitions = 2) + testUtils.createTopic(topic2, partitions = 2) + + val props: Properties = getKafkaConsumerProperties + val producer1: Producer[String, String] = new KafkaProducer[String, String](props) + val producer2: Producer[String, String] = new KafkaProducer[String, String](props) + + val readStream1 = spark.readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic1) + .load() + + val readStream2 = spark.readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic2) + .load() + + val df = readStream1 + .union(readStream2) + .selectExpr("CAST(key AS STRING) AS key", "CAST(value AS STRING) AS value") + .selectExpr("key || ',' || value") + .toDF() + + q = runStreamingQuery("union-kafka", df) + + waitForTasksToStart(4) + + val expectedResults = new mutable.ListBuffer[String]() + for (batch <- 0 until 3) { + (1 to 100).foreach(i => { + producer1 + .send( + new ProducerRecord[String, String]( + topic1, + java.lang.Long.toString(i), + s"input1-${batch}-${i}" + ) + ) + .get() + producer2 + .send( + new ProducerRecord[String, String]( + topic2, + 
java.lang.Long.toString(i), + s"input2-${batch}-${i}" + ) + ) + .get() + }) + producer1.flush() + producer2.flush() + + expectedResults ++= (1 to 100) + .flatMap(v => { + Seq( + s"${v},input1-${batch}-${v}", + s"${v},input2-${batch}-${v}" + ) + }) + .toList + + eventually(timeout(60.seconds)) { + ResultsCollector + .get(sinkName) + .toArray(new Array[String](ResultsCollector.get(sinkName).size())) + .toList + .sorted should equal(expectedResults.sorted) + } + } + } finally { + if (q != null) { + q.stop() + } + } + } +} + + +/** + * Kafka Real-Time Integration test suite. + * Tests with a distributed spark cluster with + * separate executors processes deployed. + */ +class KafkaRealTimeIntegrationSuite + extends KafkaSourceTest + with StreamRealTimeModeSuiteBase + with ThreadAudit + with BeforeAndAfterEach + with Matchers { + + override protected def createSparkSession = + new TestSparkSession( + new SparkContext( + "local-cluster[3, 5, 1024]", // Ensure we have enough for both stages. + "microbatch-context", + sparkConf + .set("spark.sql.testkey", "true") + .set("spark.scheduler.mode", "FAIR") + .set("spark.executor.extraJavaOptions", "-Dio.netty.leakDetection.level=paranoid") + ) + ) + + override def beforeAll(): Unit = { + super.beforeAll() + // testing to make sure the cluster is usable + testUtils.createTopic("_test") + testUtils.sendMessage(new ProducerRecord[String, String]("_test", "", "")) + testUtils.deleteTopic("_test") + logInfo("Kafka cluster setup complete....") + + eventually(timeout(10.seconds)) { + val executors = sparkContext.getExecutorIds() + assert(executors.size == 3, s"executors: ${executors}}") + } + } + + test("e2e stateless") { + var query: StreamingQuery = null + try { + val inputTopic = newTopic() + testUtils.createTopic(inputTopic, partitions = 5) + + val outputTopic = newTopic() + testUtils.createTopic(outputTopic, partitions = 5) + + val props: Properties = new Properties() + + props.put("bootstrap.servers", testUtils.brokerAddress) + props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer"); + props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer"); + props.put("compression.type", "snappy") + + val producer: Producer[String, String] = new KafkaProducer[String, String](props) + + query = spark.readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("startingOffsets", "earliest") + .option("subscribe", inputTopic) + .option("kafka.fetch.max.wait.ms", 10) + .load() + .withColumn("value", substring(col("value"), 0, 500 * 1000)) + .withColumn("value", base64(col("value"))) + .withColumn( + "headers", + array( + struct( + lit("source-timestamp") as "key", + unix_millis(col("timestamp")).cast("STRING").cast("BINARY") as "value" + ) + ) + ) + .drop(col("timestamp")) + .writeStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", outputTopic) + .option("checkpointLocation", Files.createTempDirectory("some-prefix").toFile.getName) + .option("kafka.max.block.ms", "100") + .trigger(RealTimeTrigger.apply("5 minutes")) + .outputMode(OutputMode.Update()) + .start() + + waitForTasksToStart(5) + + var expectedResults: ListBuffer[GenericRowWithSchema] = new ListBuffer + for (i <- 0 until 3) { + (1 to 100).foreach(i => { + producer + .send( + new ProducerRecord[String, String]( + inputTopic, + java.lang.Long.toString(i), + s"payload-${i}" + ) + ) + .get() + }) + + producer.flush() 
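The read-back that follows reverses the value encoding applied on the write side (truncate, then base64-encode). In isolation the round-trip looks like the sketch below; it assumes a local `SparkSession` named `spark` with its implicits imported, which is not shown in this test:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Write side: truncate to 500 KB and base64-encode; read side: unbase64 and cast back.
val encoded = Seq("payload-1").toDF("value")
  .withColumn("value", base64(substring(col("value"), 0, 500 * 1000).cast("BINARY")))
val decoded = encoded.withColumn("value", unbase64(col("value")).cast("STRING"))
// decoded.first().getString(0) == "payload-1"
```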
+ + val kafkaSinkData = spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", outputTopic) + .option("includeHeaders", "true") + .option("startingOffsets", "earliest") + .load() + .withColumn("value", unbase64(col("value")).cast("STRING")) + .withColumn("headers-map", map_from_entries(col("headers"))) + .withColumn("source-timestamp", conv(hex(col("headers-map.source-timestamp")), 16, 10)) + .withColumn("sink-timestamp", unix_millis(col("timestamp"))) + + // Check the answers + val newResults = (1 to 100) + .map(v => { + new GenericRowWithSchema( + Array(s"payload-${v}"), + schema = new StructType().add(StructField("value", StringType)) + ) + }) + .toList + + expectedResults ++= newResults + expectedResults = + expectedResults.sorted((x: GenericRowWithSchema, y: GenericRowWithSchema) => { + x.getString(0).compareTo(y.getString(0)) + }) + + eventually(timeout(1.minute)) { + checkAnswer(kafkaSinkData.select("value"), expectedResults.toSeq) + } + } + } finally { + if (query != null) { + query.stop() + } + } + } +} diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala new file mode 100644 index 0000000000000..83aae64d84f7e --- /dev/null +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala @@ -0,0 +1,681 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.scalatest.matchers.should.Matchers +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{SparkConf, SparkContext, SparkIllegalStateException} +import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemorySink +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{StreamingQuery, Trigger} +import org.apache.spark.sql.streaming.OutputMode.Update +import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock +import org.apache.spark.sql.test.TestSparkSession +import org.apache.spark.util.SystemClock + +class KafkaRealTimeModeSuite + extends KafkaSourceTest + with Matchers { + + override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds") + + override protected def sparkConf: SparkConf = { + // Should turn to use StreamingShuffleManager when it is ready. 
+ super.sparkConf + .set("spark.databricks.streaming.realTimeMode.enabled", "true") + .set( + SQLConf.STATE_STORE_PROVIDER_CLASS, + classOf[RocksDBStateStoreProvider].getName) + } + + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[8]", // Ensure enough number of cores to ensure concurrent schedule of all tasks. + "streaming-rtm-context", + sparkConf.set("spark.sql.testkey", "true"))) + + + import testImplicits._ + + val sleepOneSec = new ExternalAction() { + override def runAction(): Unit = { + Thread.sleep(1000) + } + } + + var clock = new GlobalSingletonManualClock() + + private def advanceRealTimeClock(timeMs: Int) = new ExternalAction { + override def runAction(): Unit = { + clock.advance(timeMs) + } + + override def toString(): String = { + s"advanceRealTimeClock($timeMs)" + } + } + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set( + SQLConf.STREAMING_REAL_TIME_MODE_MIN_BATCH_DURATION, + defaultTrigger.batchDurationMs + ) + } + + override def beforeEach(): Unit = { + super.beforeEach() + GlobalSingletonManualClock.reset() + } + + override def afterEach(): Unit = { + LowLatencyClock.setClock(new SystemClock) + super.afterEach() + } + + def waitUntilBatchStartedOrProcessed(q: StreamingQuery, batchId: Long): Unit = { + eventually(timeout(60.seconds)) { + val tasksRunning = + spark.sparkContext.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum + val lastBatch = { + if (q.lastProgress == null) { + -1 + } else { + q.lastProgress.batchId + } + } + val batchStarted = tasksRunning >= 1 && lastBatch >= batchId - 1 + val batchProcessed = lastBatch >= batchId + assert(batchStarted || batchProcessed, + s"tasksRunning: ${tasksRunning} lastBatch: ${lastBatch}") + } + } + + // A simple unit test that reads from Kakfa source, does a simple map and writes to memory + // sink. + test("simple map") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + + testUtils.sendMessages(topic, Array("1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("3"), Some(1)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + testStream(reader, Update, sink = new ContinuousMemorySink())( + StartStream(), + CheckAnswerWithTimeout(60000, 2, 3, 4), + sleepOneSec, + sleepOneSec, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("4", "5"), Some(0)) + testUtils.sendMessages(topic, Array("6"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7), + WaitUntilCurrentBatchProcessed, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("7"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8), + WaitUntilCurrentBatchProcessed, + StopStream, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("8"), Some(0)) + testUtils.sendMessages(topic, Array("9"), Some(1)) + } + }, + StartStream(), + CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8, 9, 10), + WaitUntilCurrentBatchProcessed) + } + + // A simple unit test that reads from Kakfa source, does a simple map and writes to memory + // sink. Make sure there is no data for a whole batch. Also, after restart the first batch + // has no data. 
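These suites inject data mid-stream through anonymous `ExternalAction`s, as in the test above and the ones that follow. A reusable equivalent would look like the hypothetical helper below (the helper name is an assumption; `testUtils.sendMessages` is the API already used throughout this suite):

```scala
// Hypothetical helper: one ExternalAction that sends a single message to a partition.
def send(topic: String, value: String, partition: Int): ExternalAction = new ExternalAction {
  override def runAction(): Unit =
    testUtils.sendMessages(topic, Array(value), Some(partition))
  override def toString(): String = s"send($topic, $value, p$partition)"
}
// usage inside testStream(...): send(topic, "4", 0) instead of an anonymous action
```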
+ test("simple map with empty batch") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + + val reader = spark.readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + testStream(reader, Update, sink = new ContinuousMemorySink())( + StartStream(), + WaitUntilBatchProcessed(0), + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("1"), Some(0)) + testUtils.sendMessages(topic, Array("2"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3), + WaitUntilCurrentBatchProcessed, + WaitUntilCurrentBatchProcessed, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("3"), Some(1)) + } + }, + WaitUntilCurrentBatchProcessed, + CheckAnswerWithTimeout(5000, 2, 3, 4), + StopStream, + StartStream(), + WaitUntilCurrentBatchProcessed, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("4"), Some(0)) + testUtils.sendMessages(topic, Array("5"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6), + WaitUntilCurrentBatchProcessed + ) + } + + // A simple unit test that reads from Kakfa source, does a simple map and writes to memory + // sink. + test("add partition") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + + testUtils.sendMessages(topic, Array("1", "2", "3"), Some(0)) + + val reader = spark.readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + testStream(reader, Update, sink = new ContinuousMemorySink())( + StartStream(), + CheckAnswerWithTimeout(60000, 2, 3, 4), + sleepOneSec, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.addPartitions(topic, 2) + testUtils.sendMessages(topic, Array("4", "5"), Some(0)) + testUtils.sendMessages(topic, Array("6"), Some(1)) + } + }, + CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7), + WaitUntilCurrentBatchProcessed, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.addPartitions(topic, 4) + testUtils.sendMessages(topic, Array("7"), Some(2)) + } + }, + CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7, 8), + WaitUntilCurrentBatchProcessed, + StopStream, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("8"), Some(3)) + testUtils.sendMessages(topic, Array("9"), Some(2)) + } + }, + StartStream(), + CheckAnswerWithTimeout(15000, 2, 3, 4, 5, 6, 7, 8, 9, 10), + WaitUntilCurrentBatchProcessed + ) + } + + test("Real-Time Mode fetches latestOffset again at end of the batch") { + // LowLatencyClock does not affect the wait time of kafka iterator, so advancing the clock + // does not affect the test finish time. The purpose of using it is to make the query start + // time consistent, so the test behaves the same. 
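The wiring behind that comment, pulled together in one place for reference (all names come from this suite; the millisecond value is arbitrary):

```scala
// Manual-clock wiring used by this test: install the singleton manual clock, advance it
// from test actions, and restore the system clock afterwards.
val manualClock = new GlobalSingletonManualClock()
LowLatencyClock.setClock(manualClock)   // done on the next line for this test
manualClock.advance(2000)               // what advanceRealTimeClock(2000) performs
// beforeEach(): GlobalSingletonManualClock.reset()
// afterEach():  LowLatencyClock.setClock(new SystemClock)
```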
+ LowLatencyClock.setClock(clock) + val topic = newTopic() + testUtils.createTopic(topic, partitions = 1) + + val reader = spark.readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + // extra large number to make sure fetch does + // not return within batch duration + .option("kafka.fetch.max.wait.ms", "20000000") + .option("kafka.fetch.min.bytes", "20000000") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + testStream(reader, Update, sink = new ContinuousMemorySink())( + StartStream(Trigger.RealTime(10000)), + advanceRealTimeClock(2000), + Execute { q => + waitUntilBatchStartedOrProcessed(q, 0) + testUtils.sendMessages(topic, Array("1"), Some(0)) + }, + advanceRealTimeClock(8000), + WaitUntilBatchProcessed(0), + Execute { q => + val expectedMetrics = Map( + "minOffsetsBehindLatest" -> "1", + "maxOffsetsBehindLatest" -> "1", + "avgOffsetsBehindLatest" -> "1.0", + "estimatedTotalBytesBehindLatest" -> null + ) + eventually(timeout(60.seconds)) { + expectedMetrics.foreach { case (metric, expectedValue) => + assert(q.lastProgress.sources(0).metrics.get(metric) === expectedValue) + } + } + }, + advanceRealTimeClock(2000), + Execute { q => + waitUntilBatchStartedOrProcessed(q, 1) + testUtils.sendMessages(topic, Array("2", "3"), Some(0)) + }, + advanceRealTimeClock(8000), + WaitUntilBatchProcessed(1), + Execute { q => + val expectedMetrics = Map( + "minOffsetsBehindLatest" -> "3", + "maxOffsetsBehindLatest" -> "3", + "avgOffsetsBehindLatest" -> "3.0", + "estimatedTotalBytesBehindLatest" -> null + ) + eventually(timeout(60.seconds)) { + expectedMetrics.foreach { case (metric, expectedValue) => + assert(q.lastProgress.sources(0).metrics.get(metric) === expectedValue) + } + } + } + ) + } + + // Validate the query fails with minOffsetPerTrigger option set. 
+ Seq( + "maxoffsetspertrigger", + "minoffsetspertrigger", + "minpartitions", + "endingtimestamp", + "maxtriggerdelay").foreach { opt => + test(s"$opt incompatible") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + + val reader = spark.readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .option(opt, "5") + .load() + testStream(reader, Update, sink = new ContinuousMemorySink())( + StartStream(), + ExpectFailure[UnsupportedOperationException] { (t: Throwable) => { + assert(t.getMessage.toLowerCase().contains(opt)) + } + } + ) + } + } + + test("union 2 dataframes after projection") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + + val topic1 = newTopic() + testUtils.createTopic(topic1, partitions = 2) + + testUtils.sendMessages(topic, Array("1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("3"), Some(1)) + + testUtils.sendMessages(topic1, Array("11", "12"), Some(0)) + testUtils.sendMessages(topic1, Array("13"), Some(1)) + + val reader1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + val reader2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic1) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + val unionedReader = reader1.union(reader2) + + testStream(unionedReader, Update, sink = new ContinuousMemorySink())( + StartStream(), + CheckAnswerWithTimeout(60000, 2, 3, 4, 12, 13, 14), + sleepOneSec, + sleepOneSec, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("4", "5"), Some(0)) + testUtils.sendMessages(topic, Array("6"), Some(1)) + testUtils.sendMessages(topic1, Array("14", "15"), Some(0)) + testUtils.sendMessages(topic1, Array("16"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17), + WaitUntilCurrentBatchProcessed, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("7"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17, 8), + WaitUntilCurrentBatchProcessed, + StopStream, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("8"), Some(0)) + testUtils.sendMessages(topic, Array("9"), Some(1)) + testUtils.sendMessages(topic1, Array("19"), Some(1)) + } + }, + StartStream(), + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 16, 17, 8, 9, 10, 20), + WaitUntilCurrentBatchProcessed) + } + + test("union 3 dataframes with and without maxPartitions") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + + val topic1 = newTopic() + testUtils.createTopic(topic1, partitions = 2) + + val topic2 = newTopic() + testUtils.createTopic(topic2, partitions = 2) + + testUtils.sendMessages(topic, Array("1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("3"), Some(1)) + + testUtils.sendMessages(topic1, Array("11", "12"), Some(0)) + testUtils.sendMessages(topic1, Array("13"), Some(1)) + + testUtils.sendMessages(topic2, Array("21", "22"), Some(0)) + testUtils.sendMessages(topic2, Array("23"), 
Some(1)) + + val reader1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .option("maxPartitions", "1") + .load() + + val reader2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic1) + .option("startingOffsets", "earliest") + .load() + + val reader3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic2) + .option("startingOffsets", "earliest") + .option("maxPartitions", "3") + .load() + + val unionedReader = reader1.union(reader2).union(reader3) + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + testStream(unionedReader, Update, sink = new ContinuousMemorySink())( + StartStream(), + CheckAnswerWithTimeout(10000, 2, 3, 4, 12, 13, 14, 22, 23, 24), + sleepOneSec, + sleepOneSec, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("4"), Some(0)) + testUtils.sendMessages(topic, Array("5"), Some(1)) + testUtils.sendMessages(topic1, Array("14", "15"), Some(0)) + testUtils.sendMessages(topic2, Array("24"), Some(0)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15, 16, 25), + WaitUntilCurrentBatchProcessed, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("6"), Some(1)) + testUtils.sendMessages(topic2, Array("25"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15, 16, 25, 7, 26), + WaitUntilCurrentBatchProcessed, + StopStream, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic1, Array("16"), Some(1)) + } + }, + StartStream(), + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 22, 23, 24, 5, 6, 15, 16, 25, 7, 26, 17), + WaitUntilCurrentBatchProcessed) + } + + test("self union workaround") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + + testUtils.sendMessages(topic, Array("1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("3"), Some(1)) + + val reader1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + + + val reader2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + + val unionedReader = reader1.union(reader2) + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + testStream(unionedReader, Update, sink = new ContinuousMemorySink())( + StartStream(), + CheckAnswerWithTimeout(60000, 2, 3, 4, 2, 3, 4), + sleepOneSec, + sleepOneSec, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("4", "5"), Some(0)) + testUtils.sendMessages(topic, Array("6"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7), + WaitUntilCurrentBatchProcessed, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("7"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7, 8, 8), + WaitUntilCurrentBatchProcessed, + StopStream, + new ExternalAction() { + override def runAction(): Unit = { + 
testUtils.sendMessages(topic, Array("8"), Some(0)) + testUtils.sendMessages(topic, Array("9"), Some(1)) + } + }, + StartStream(), + CheckAnswerWithTimeout(5000, 2, 3, 4, 2, 3, 4, 5, 6, 7, 5, 6, 7, 8, 8, 9, 10, 9, 10), + WaitUntilCurrentBatchProcessed) + } + + test("union 2 different sources - Kafka and LowLatencyMemoryStream") { + import testImplicits._ + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + + val memoryStreamRead = LowLatencyMemoryStream[String](2) + + testUtils.sendMessages(topic, Array("1", "2"), Some(0)) + testUtils.sendMessages(topic, Array("3"), Some(1)) + memoryStreamRead.addData("11", "12", "13") + + val reader1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + + + val reader2 = memoryStreamRead.toDF() + .selectExpr("CAST(value AS STRING)") + .as[String] + + val unionedReader = reader1.union(reader2) + .map(_.toInt) + .map(_ + 1) + + testStream(unionedReader, Update, sink = new ContinuousMemorySink())( + StartStream(), + CheckAnswerWithTimeout(60000, 2, 3, 4, 12, 13, 14), + sleepOneSec, + sleepOneSec, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("4", "5"), Some(0)) + testUtils.sendMessages(topic, Array("6"), Some(1)) + memoryStreamRead.addData("14") + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15), + WaitUntilCurrentBatchProcessed, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("7"), Some(1)) + } + }, + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 8), + WaitUntilCurrentBatchProcessed, + StopStream, + new ExternalAction() { + override def runAction(): Unit = { + testUtils.sendMessages(topic, Array("8"), Some(0)) + testUtils.sendMessages(topic, Array("9"), Some(1)) + memoryStreamRead.addData("15", "16", "17") + } + }, + StartStream(), + CheckAnswerWithTimeout(5000, 2, 3, 4, 12, 13, 14, 5, 6, 7, 15, 8, 9, 10, 16, 17, 18), + WaitUntilCurrentBatchProcessed) + } + + test("self union - not allowed") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 2) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + .load() + + val unionedReader = reader.union(reader) + .selectExpr("CAST(value AS STRING)") + .as[String] + .map(_.toInt) + .map(_ + 1) + + testStream(unionedReader, Update, sink = new ContinuousMemorySink())( + StartStream(), + ExpectFailure[SparkIllegalStateException] { ex => + checkErrorMatchPVals( + ex.asInstanceOf[SparkIllegalStateException], + "STREAMING_REAL_TIME_MODE.IDENTICAL_SOURCES_IN_UNION_NOT_SUPPORTED", + parameters = + Map("sources" -> "(?s).*") + ) + } + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala index 5bb01bdea26e5..9199580f65872 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamRealTimeModeSuiteBase.scala @@ -17,11 +17,18 @@ package org.apache.spark.sql.streaming +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} + +import scala.collection.mutable + 
import org.scalatest.matchers.should.Matchers +import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.ForeachWriter import org.apache.spark.sql.execution.datasources.v2.LowLatencyClock -import org.apache.spark.sql.execution.streaming.RealTimeTrigger +import org.apache.spark.sql.execution.streaming.{LowLatencyMemoryStream, RealTimeTrigger} +import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock import org.apache.spark.sql.test.TestSparkSession @@ -43,6 +50,124 @@ trait StreamRealTimeModeSuiteBase extends StreamTest with Matchers { "local[10]", // Ensure enough number of cores to ensure concurrent schedule of all tasks. "streaming-rtm-context", sparkConf.set("spark.sql.testkey", "true"))) + + /** + * Should only be used in real-time mode where the batch duration is long enough to ensure + * eventually does not skip the batch due to long refresh interval. + */ + def waitForTasksToStart(numTasks: Int): Unit = { + eventually(timeout(60.seconds)) { + val tasksRunning = spark.sparkContext.statusTracker + .getExecutorInfos.map(_.numRunningTasks()).sum + assert(tasksRunning == numTasks, s"tasksRunning: ${tasksRunning}") + } + } +} + +/** + * Must be a singleton object to ensure serializable when used in ForeachWriter. + * Users must make sure different test suites use different sink names to avoid race conditions. + */ +object ResultsCollector extends ConcurrentHashMap[String, ConcurrentLinkedQueue[String]] { + def reset(): Unit = { + clear() + } +} + +/** + * Base class that contains helper methods to test Real-Time Mode streaming queries. + * + * The general procedure to use this suite is as follows: + * 1. Call createMemoryStream to create a memory stream with manual clock. + * 2. Call runStreamingQuery to start a streaming query with custom logic. + * 3. Call processBatches to add data to the memory stream and validate results. + * + * It uses foreach to collect results into [[ResultsCollector]]. It also tests whether + * results are emitted in real-time by having longer batch durations than the waiting time. + */ +trait StreamRealTimeModeE2ESuiteBase extends StreamRealTimeModeSuiteBase { + import testImplicits._ + + override protected val defaultTrigger = RealTimeTrigger.apply("300 seconds") + + protected final def sinkName: String = getClass.getName + "Sink" + + override def beforeEach(): Unit = { + super.beforeEach() + ResultsCollector.reset() + } + + // Create a ForeachWriter that collects results into ResultsCollector. 
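Taken together, a concrete E2E suite drives these helpers the way the Kafka suite earlier in this patch does. A condensed usage sketch (the streaming DataFrame `df` and the `expected` buffer are placeholders):

```scala
// Condensed from KafkaRealTimeModeE2ESuite: start the query through runStreamingQuery,
// wait for tasks, then assert on what the ForeachWriter accumulated in ResultsCollector.
val query = runStreamingQuery("union-kafka", df)
waitForTasksToStart(4)
// ... produce input and extend `expected` accordingly ...
eventually(timeout(60.seconds)) {
  val collected = ResultsCollector.get(sinkName)
  collected.toArray(new Array[String](collected.size())).toList.sorted should
    equal(expected.sorted)
}
query.stop()
```

The `foreachWriter` defined next is what `runStreamingQuery` installs as the sink.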
+ def foreachWriter(sinkName: String): ForeachWriter[String] = new ForeachWriter[String] { + override def open(partitionId: Long, epochId: Long): Boolean = { + true + } + + override def process(value: String): Unit = { + val collector = + ResultsCollector.computeIfAbsent(sinkName, (_) => new ConcurrentLinkedQueue[String]()) + collector.add(value) + } + + override def close(errorOrNull: Throwable): Unit = {} + } + + def createMemoryStream(numPartitions: Int = 5) + : (LowLatencyMemoryStream[(String, Int)], GlobalSingletonManualClock) = { + val clock = new GlobalSingletonManualClock() + LowLatencyClock.setClock(clock) + val read = LowLatencyMemoryStream[(String, Int)](numPartitions) + (read, clock) + } + + def runStreamingQuery(queryName: String, df: org.apache.spark.sql.DataFrame): StreamingQuery = { + df.as[String] + .writeStream + .outputMode(OutputMode.Update()) + .foreach(foreachWriter(sinkName)) + .queryName(queryName) + .trigger(defaultTrigger) + .start() + } + + // Add test data to the memory source and validate results + def processBatches( + query: StreamingQuery, + read: LowLatencyMemoryStream[(String, Int)], + clock: GlobalSingletonManualClock, + numRowsPerBatch: Int, + numBatches: Int, + expectedResultsGenerator: (String, Int) => Array[String]): Unit = { + val expectedResults = mutable.ListBuffer[String]() + for (i <- 0 until numBatches) { + for (key <- List("a", "b", "c")) { + for (j <- 1 to numRowsPerBatch) { + val value = i * numRowsPerBatch + j + read.addData((key, value)) + expectedResults ++= expectedResultsGenerator(key, value) + } + } + + eventually(timeout(60.seconds)) { + ResultsCollector + .get(sinkName) + .toArray(new Array[String](ResultsCollector.get(sinkName).size())) + .toList + .sorted should equal(expectedResults.sorted) + } + + clock.advance(defaultTrigger.batchDurationMs) + + eventually(timeout(60.seconds)) { + query + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + .getLatestExecutionContext() + .batchId should be(i + 1) + query.lastProgress.sources(0).numInputRows should be(numRowsPerBatch * 3) + } + } + } } abstract class StreamRealTimeModeManualClockSuiteBase extends StreamRealTimeModeSuiteBase { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 3ff8cab64d652..e57c4e1e665cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -284,6 +284,8 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with case class WaitUntilBatchProcessed(batchId: Long) extends StreamAction with StreamMustBeRunning + case object WaitUntilCurrentBatchProcessed extends StreamAction with StreamMustBeRunning + /** * Signals that a failure is expected and should not kill the test. 
* @@ -659,6 +661,20 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with throw currentStream.exception.get } + case WaitUntilCurrentBatchProcessed => + if (currentStream.exception.isDefined) { + throw currentStream.exception.get + } + val currBatch = currentStream.commitLog.getLatestBatchId().getOrElse(-1L) + eventually("Current batch never finishes") { + assert(currentStream.commitLog.getLatestBatchId() != None + && currentStream.commitLog.getLatestBatchId().get > currBatch) + + // See WaitUntilBatchProcessed for an explanation of why we wait for the progress + val latestProgressBatchId = + currentStream.recentProgress.lastOption.map(_.batchId).getOrElse(-1L) + assert(latestProgressBatchId >= currBatch) + } case StopStream => verify(currentStream != null, "can not stop a stream that is not running") try failAfter(streamingTimeout) {
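For reference, the new `WaitUntilCurrentBatchProcessed` action is used in the real-time suites where the id of the next batch is not known in advance; a typical call site, condensed from `KafkaRealTimeModeSuite` above:

```scala
// Unlike WaitUntilBatchProcessed(batchId), this waits for whatever batch is currently
// open to commit, which is what the real-time tests need after injecting data.
testStream(reader, Update, sink = new ContinuousMemorySink())(
  StartStream(),
  CheckAnswerWithTimeout(60000, 2, 3, 4),
  WaitUntilCurrentBatchProcessed,
  StopStream)
```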