diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
index 2d1449bd96cb..5b9e9bea0a38 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
@@ -30,33 +30,9 @@ import org.apache.spark.sql.streaming.StreamingQueryException
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
-class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
-
-  import testImplicits._
-
+class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
   val waitTimeout = 15.seconds
 
-  protected def testDataStreamReaderScript: String =
-    """
-      |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
-      |
-      |class TestDataStreamReader(DataSourceStreamReader):
-      |    current = 0
-      |    def initialOffset(self):
-      |        return {"offset": {"partition-1": 0}}
-      |    def latestOffset(self):
-      |        self.current += 2
-      |        return {"offset": {"partition-1": self.current}}
-      |    def partitions(self, start: dict, end: dict):
-      |        start_index = start["offset"]["partition-1"]
-      |        end_index = end["offset"]["partition-1"]
-      |        return [InputPartition(i) for i in range(start_index, end_index)]
-      |    def commit(self, end: dict):
-      |        1 + 2
-      |    def read(self, partition):
-      |        yield (partition.value,)
-      |""".stripMargin
-
   protected def simpleDataStreamReaderScript: String =
     """
       |from pyspark.sql.datasource import SimpleDataSourceStreamReader
@@ -94,93 +70,8 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
       |        return iter([(i, ) for i in range(start_idx, end_idx)])
       |""".stripMargin
 
-  protected def errorDataStreamReaderScript: String =
-    """
-      |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
-      |
-      |class ErrorDataStreamReader(DataSourceStreamReader):
-      |    def initialOffset(self):
-      |        raise Exception("error reading initial offset")
-      |    def latestOffset(self):
-      |        raise Exception("error reading latest offset")
-      |    def partitions(self, start: dict, end: dict):
-      |        raise Exception("error planning partitions")
-      |    def commit(self, end: dict):
-      |        raise Exception("error committing offset")
-      |    def read(self, partition):
-      |        yield (0, partition.value)
-      |        yield (1, partition.value)
-      |        yield (2, partition.value)
-      |""".stripMargin
-
-  protected def simpleDataStreamWriterScript: String =
-    s"""
-       |import json
-       |import uuid
-       |import os
-       |from pyspark import TaskContext
-       |from pyspark.sql.datasource import DataSource, DataSourceStreamWriter
-       |from pyspark.sql.datasource import WriterCommitMessage
-       |
-       |class SimpleDataSourceStreamWriter(DataSourceStreamWriter):
-       |    def __init__(self, options, overwrite):
-       |        self.options = options
-       |        self.overwrite = overwrite
-       |
-       |    def write(self, iterator):
-       |        context = TaskContext.get()
-       |        partition_id = context.partitionId()
-       |        path = self.options.get("path")
-       |        assert path is not None
-       |        output_path = os.path.join(path, f"{partition_id}.json")
-       |        cnt = 0
-       |        mode = "w" if self.overwrite else "a"
-       |        with open(output_path, mode) as file:
-       |            for row in iterator:
-       |                file.write(json.dumps(row.asDict()) + "\\n")
-       |        return WriterCommitMessage()
-       |
-       |class SimpleDataSource(DataSource):
-       |    def schema(self) -> str:
-       |        return "id INT"
-       |    def streamWriter(self, schema, overwrite):
-       |        return SimpleDataSourceStreamWriter(self.options, overwrite)
-       |""".stripMargin
-
   private val errorDataSourceName = "ErrorDataSource"
 
-  test("Test PythonMicroBatchStream") {
-    assume(shouldTestPandasUDFs)
-    val dataSourceScript =
-      s"""
-         |from pyspark.sql.datasource import DataSource
-         |$testDataStreamReaderScript
-         |
-         |class $dataSourceName(DataSource):
-         |    def streamReader(self, schema):
-         |        return TestDataStreamReader()
-         |""".stripMargin
-    val inputSchema = StructType.fromDDL("input BINARY")
-
-    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
-    spark.dataSource.registerPython(dataSourceName, dataSource)
-    val pythonDs = new PythonDataSourceV2
-    pythonDs.setShortName("SimpleDataSource")
-    val stream = new PythonMicroBatchStream(
-      pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
-
-    var startOffset = stream.initialOffset()
-    assert(startOffset.json == "{\"offset\": {\"partition-1\": 0}}")
-    for (i <- 1 to 50) {
-      val endOffset = stream.latestOffset()
-      assert(endOffset.json == s"""{"offset": {"partition-1": ${2 * i}}}""")
-      assert(stream.planInputPartitions(startOffset, endOffset).size == 2)
-      stream.commit(endOffset)
-      startOffset = endOffset
-    }
-    stream.stop()
-  }
-
   test("SimpleDataSourceStreamReader run query and restart") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =
@@ -204,8 +95,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
       val stopSignal1 = new CountDownLatch(1)
 
-      val q1 = df
-        .writeStream
+      val q1 = df.writeStream
         .option("checkpointLocation", checkpointDir.getAbsolutePath)
         .foreachBatch((df: DataFrame, batchId: Long) => {
           df.cache()
@@ -219,14 +109,14 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
       q1.awaitTermination()
 
       val stopSignal2 = new CountDownLatch(1)
-      val q2 = df
-        .writeStream
+      val q2 = df.writeStream
         .option("checkpointLocation", checkpointDir.getAbsolutePath)
         .foreachBatch((df: DataFrame, batchId: Long) => {
           df.cache()
           checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
           if (batchId == 20) stopSignal2.countDown()
-        }).start()
+        })
+        .start()
       stopSignal2.await()
       assert(q2.recentProgress.forall(_.numInputRows == 2))
       q2.stop()
@@ -259,8 +149,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
       val stopSignal = new CountDownLatch(1)
 
-      val q = df
-        .writeStream
+      val q = df.writeStream
         .option("checkpointLocation", checkpointDir.getAbsolutePath)
         .foreachBatch((df: DataFrame, batchId: Long) => {
           df.cache()
@@ -305,15 +194,13 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
         if (i % 2 == 0) {
           // Remove the last entry of commit log to test replaying microbatch during restart.
-          val offsetLog = new OffsetSeqLog(
-            spark, new File(checkpointDir, "offsets").getCanonicalPath)
-          val commitLog = new CommitLog(
-            spark, new File(checkpointDir, "commits").getCanonicalPath)
+          val offsetLog =
+            new OffsetSeqLog(spark, new File(checkpointDir, "offsets").getCanonicalPath)
+          val commitLog = new CommitLog(spark, new File(checkpointDir, "commits").getCanonicalPath)
           commitLog.purgeAfter(offsetLog.getLatest().get._1 - 1)
         }
 
-        val q = df
-          .writeStream
+        val q = df.writeStream
           .option("checkpointLocation", checkpointDir.getAbsolutePath)
          .format("json")
          .start(outputDir.getAbsolutePath)
@@ -330,8 +217,10 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
       // There may be one uncommitted batch that is not recorded in query progress.
      // The number of batch can be lastBatchId + 1 or lastBatchId + 2.
       assert(rowCount == 2 * (lastBatchId + 1) || rowCount == 2 * (lastBatchId + 2))
-      checkAnswer(spark.read.format("json").load(outputDir.getAbsolutePath),
-        (0 until rowCount.toInt).map(Row(_)))
+      checkAnswer(
+        spark.read.format("json").load(outputDir.getAbsolutePath),
+        (0 until rowCount.toInt).map(Row(_))
+      )
     }
   }
@@ -356,35 +245,45 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     val pythonDs = new PythonDataSourceV2
     pythonDs.setShortName("ErrorDataSource")
 
-    def testMicroBatchStreamError(action: String, msg: String)
-      (func: PythonMicroBatchStream => Unit): Unit = {
+    def testMicroBatchStreamError(action: String, msg: String)(
+        func: PythonMicroBatchStream => Unit): Unit = {
       val stream = new PythonMicroBatchStream(
-        pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+        pythonDs,
+        errorDataSourceName,
+        inputSchema,
+        CaseInsensitiveStringMap.empty()
+      )
       val err = intercept[SparkException] {
         func(stream)
       }
-      checkErrorMatchPVals(err,
+      checkErrorMatchPVals(
+        err,
         errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
         parameters = Map(
           "action" -> action,
           "msg" -> "(.|\\n)*"
-        ))
+        )
+      )
       assert(err.getMessage.contains(msg))
       assert(err.getMessage.contains("ErrorDataSource"))
       stream.stop()
     }
 
     testMicroBatchStreamError(
-      "initialOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") {
-      stream => stream.initialOffset()
+      "initialOffset",
+      "[NOT_IMPLEMENTED] initialOffset is not implemented") {
+      stream =>
+        stream.initialOffset()
     }
 
     // User don't need to implement latestOffset for SimpleDataSourceStreamReader.
     // The latestOffset method of simple stream reader invokes initialOffset() and read()
     // So the not implemented method is initialOffset.
    testMicroBatchStreamError(
-      "latestOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") {
-      stream => stream.latestOffset()
+      "latestOffset",
+      "[NOT_IMPLEMENTED] initialOffset is not implemented") {
+      stream =>
+        stream.latestOffset()
     }
   }
 
@@ -412,29 +311,116 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     val pythonDs = new PythonDataSourceV2
     pythonDs.setShortName("ErrorDataSource")
 
-    def testMicroBatchStreamError(action: String, msg: String)
-      (func: PythonMicroBatchStream => Unit): Unit = {
+    def testMicroBatchStreamError(action: String, msg: String)(
+        func: PythonMicroBatchStream => Unit): Unit = {
       val stream = new PythonMicroBatchStream(
-        pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+        pythonDs,
+        errorDataSourceName,
+        inputSchema,
+        CaseInsensitiveStringMap.empty()
+      )
       val err = intercept[SparkException] {
         func(stream)
       }
-      checkErrorMatchPVals(err,
+      checkErrorMatchPVals(
+        err,
         errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
         parameters = Map(
           "action" -> action,
           "msg" -> "(.|\\n)*"
-        ))
+        )
+      )
       assert(err.getMessage.contains(msg))
       assert(err.getMessage.contains("ErrorDataSource"))
       stream.stop()
     }
 
-    testMicroBatchStreamError(
-      "latestOffset", "Exception: error reading available data") {
-      stream => stream.latestOffset()
+    testMicroBatchStreamError("latestOffset", "Exception: error reading available data") { stream =>
+      stream.latestOffset()
     }
   }
+}
+
+class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
+  val waitTimeout = 15.seconds
+
+  protected def testDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
+      |
+      |class TestDataStreamReader(DataSourceStreamReader):
+      |    current = 0
+      |    def initialOffset(self):
+      |        return {"offset": {"partition-1": 0}}
+      |    def latestOffset(self):
+      |        self.current += 2
+      |        return {"offset": {"partition-1": self.current}}
+      |    def partitions(self, start: dict, end: dict):
+      |        start_index = start["offset"]["partition-1"]
+      |        end_index = end["offset"]["partition-1"]
+      |        return [InputPartition(i) for i in range(start_index, end_index)]
+      |    def commit(self, end: dict):
+      |        1 + 2
+      |    def read(self, partition):
+      |        yield (partition.value,)
+      |""".stripMargin
+
+  protected def errorDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
+      |
+      |class ErrorDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        raise Exception("error reading initial offset")
+      |    def latestOffset(self):
+      |        raise Exception("error reading latest offset")
+      |    def partitions(self, start: dict, end: dict):
+      |        raise Exception("error planning partitions")
+      |    def commit(self, end: dict):
+      |        raise Exception("error committing offset")
+      |    def read(self, partition):
+      |        yield (0, partition.value)
+      |        yield (1, partition.value)
+      |        yield (2, partition.value)
+      |""".stripMargin
+
+  private val errorDataSourceName = "ErrorDataSource"
+
+  test("Test PythonMicroBatchStream") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$testDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        return TestDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("SimpleDataSource")
+    val stream = new PythonMicroBatchStream(
+      pythonDs,
+      dataSourceName,
+      inputSchema,
+      CaseInsensitiveStringMap.empty()
+    )
+
+    var startOffset = stream.initialOffset()
+    assert(startOffset.json == "{\"offset\": {\"partition-1\": 0}}")
+    for (i <- 1 to 50) {
+      val endOffset = stream.latestOffset()
+      assert(endOffset.json == s"""{"offset": {"partition-1": ${2 * i}}}""")
+      assert(stream.planInputPartitions(startOffset, endOffset).size == 2)
+      stream.commit(endOffset)
+      startOffset = endOffset
+    }
+    stream.stop()
+  }
 
   test("Read from test data stream source") {
     assume(shouldTestPandasUDFs)
@@ -457,13 +443,16 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
     val stopSignal = new CountDownLatch(1)
 
-    val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
-      // checkAnswer may materialize the dataframe more than once
-      // Cache here to make sure the numInputRows metrics is consistent.
-      df.cache()
-      checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
-      if (batchId > 30) stopSignal.countDown()
-    }).trigger(ProcessingTimeTrigger(0)).start()
+    val q = df.writeStream
+      .foreachBatch((df: DataFrame, batchId: Long) => {
+        // checkAnswer may materialize the dataframe more than once
+        // Cache here to make sure the numInputRows metrics is consistent.
+        df.cache()
+        checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+        if (batchId > 30) stopSignal.countDown()
+      })
+      .trigger(ProcessingTimeTrigger(0))
+      .start()
     stopSignal.await()
     assert(q.recentProgress.forall(_.numInputRows == 2))
     q.stop()
@@ -492,13 +481,16 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
     val stopSignal = new CountDownLatch(1)
 
-    val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
-      // checkAnswer may materialize the dataframe more than once
-      // Cache here to make sure the numInputRows metrics is consistent.
-      df.cache()
-      checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
-      if (batchId >= 2) stopSignal.countDown()
-    }).trigger(ProcessingTimeTrigger(20 * 1000)).start()
+    val q = df.writeStream
+      .foreachBatch((df: DataFrame, batchId: Long) => {
+        // checkAnswer may materialize the dataframe more than once
+        // Cache here to make sure the numInputRows metrics is consistent.
+        df.cache()
+        checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+        if (batchId >= 2) stopSignal.countDown()
+      })
+      .trigger(ProcessingTimeTrigger(20 * 1000))
+      .start()
     stopSignal.await()
     assert(q.recentProgress.forall(_.numInputRows == 2))
     q.stop()
@@ -547,13 +539,16 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
     val stopSignal = new CountDownLatch(1)
 
-    val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
-      // checkAnswer may materialize the dataframe more than once
-      // Cache here to make sure the numInputRows metrics is consistent.
-      df.cache()
-      checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
-      if (batchId > 30) stopSignal.countDown()
-    }).trigger(ProcessingTimeTrigger(0)).start()
+    val q = df.writeStream
+      .foreachBatch((df: DataFrame, batchId: Long) => {
+        // checkAnswer may materialize the dataframe more than once
+        // Cache here to make sure the numInputRows metrics is consistent.
+        df.cache()
+        checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+        if (batchId > 30) stopSignal.countDown()
+      })
+      .trigger(ProcessingTimeTrigger(0))
+      .start()
     stopSignal.await()
     assert(q.recentProgress.forall(_.numInputRows == 2))
     q.stop()
@@ -571,13 +566,12 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
          |    def streamReader(self, schema):
          |        raise Exception("error creating stream reader")
          |""".stripMargin
-    val dataSource = createUserDefinedPythonDataSource(
-      name = dataSourceName, pythonScript = dataSourceScript)
+    val dataSource =
+      createUserDefinedPythonDataSource(name = dataSourceName, pythonScript = dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
 
     val err = intercept[StreamingQueryException] {
-      val q = spark.readStream.format(dataSourceName).load()
-        .writeStream.format("console").start()
+      val q = spark.readStream.format(dataSourceName).load().writeStream.format("console").start()
       q.awaitTermination()
     }
     assert(err.getErrorClass == "STREAM_FAILED")
@@ -622,10 +616,12 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     val df = spark.readStream.format(dataSourceName).load()
 
     val err = intercept[StreamingQueryException] {
-      val q = df.writeStream.foreachBatch((df: DataFrame, _: Long) => {
-        df.count()
-        ()
-      }).start()
+      val q = df.writeStream
+        .foreachBatch((df: DataFrame, _: Long) => {
+          df.count()
+          ()
+        })
+        .start()
       q.awaitTermination()
     }
     assert(err.getMessage.contains("error reading data"))
@@ -652,38 +648,46 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     val pythonDs = new PythonDataSourceV2
     pythonDs.setShortName("ErrorDataSource")
 
-    def testMicroBatchStreamError(action: String, msg: String)
-      (func: PythonMicroBatchStream => Unit): Unit = {
+    def testMicroBatchStreamError(action: String, msg: String)(
+        func: PythonMicroBatchStream => Unit): Unit = {
       val stream = new PythonMicroBatchStream(
-        pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+        pythonDs,
+        errorDataSourceName,
+        inputSchema,
+        CaseInsensitiveStringMap.empty()
+      )
       val err = intercept[SparkException] {
         func(stream)
       }
-      checkErrorMatchPVals(err,
+      checkErrorMatchPVals(
+        err,
         errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
         parameters = Map(
           "action" -> action,
           "msg" -> "(.|\\n)*"
-        ))
+        )
+      )
       assert(err.getMessage.contains(msg))
       assert(err.getMessage.contains("ErrorDataSource"))
       stream.stop()
     }
 
     testMicroBatchStreamError(
-      "initialOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") {
-      stream => stream.initialOffset()
+      "initialOffset",
+      "[NOT_IMPLEMENTED] initialOffset is not implemented") {
+      stream =>
+        stream.initialOffset()
     }
 
-    testMicroBatchStreamError(
-      "latestOffset", "[NOT_IMPLEMENTED] latestOffset is not implemented") {
-      stream => stream.latestOffset()
+    testMicroBatchStreamError("latestOffset", "[NOT_IMPLEMENTED] latestOffset is not implemented") {
+      stream =>
+        stream.latestOffset()
     }
 
     val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}")
 
-    testMicroBatchStreamError(
-      "planPartitions", "[NOT_IMPLEMENTED] partitions is not implemented") {
-      stream => stream.planInputPartitions(offset, offset)
+    testMicroBatchStreamError("planPartitions", "[NOT_IMPLEMENTED] partitions is not implemented") {
+      stream =>
+        stream.planInputPartitions(offset, offset)
     }
   }
 
@@ -706,40 +710,87 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     pythonDs.setShortName("ErrorDataSource")
     val offset = PythonStreamingSourceOffset("{\"offset\": 2}")
 
-    def testMicroBatchStreamError(action: String, msg: String)
-      (func: PythonMicroBatchStream => Unit): Unit = {
+    def testMicroBatchStreamError(action: String, msg: String)(
+        func: PythonMicroBatchStream => Unit): Unit = {
       val stream = new PythonMicroBatchStream(
-        pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+        pythonDs,
+        errorDataSourceName,
+        inputSchema,
+        CaseInsensitiveStringMap.empty()
+      )
       val err = intercept[SparkException] {
         func(stream)
       }
-      checkErrorMatchPVals(err,
+      checkErrorMatchPVals(
+        err,
         errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
         parameters = Map(
          "action" -> action,
          "msg" -> "(.|\\n)*"
-        ))
+        )
+      )
       assert(err.getMessage.contains(msg))
       assert(err.getMessage.contains("ErrorDataSource"))
       stream.stop()
     }
 
-    testMicroBatchStreamError("initialOffset", "error reading initial offset") {
-      stream => stream.initialOffset()
+    testMicroBatchStreamError("initialOffset", "error reading initial offset") { stream =>
+      stream.initialOffset()
     }
 
-    testMicroBatchStreamError("latestOffset", "error reading latest offset") {
-      stream => stream.latestOffset()
+    testMicroBatchStreamError("latestOffset", "error reading latest offset") { stream =>
+      stream.latestOffset()
    }
 
-    testMicroBatchStreamError("planPartitions", "error planning partitions") {
-      stream => stream.planInputPartitions(offset, offset)
+    testMicroBatchStreamError("planPartitions", "error planning partitions") { stream =>
+      stream.planInputPartitions(offset, offset)
     }
 
-    testMicroBatchStreamError("commitSource", "error committing offset") {
-      stream => stream.commit(offset)
+    testMicroBatchStreamError("commitSource", "error committing offset") { stream =>
+      stream.commit(offset)
     }
   }
+}
+
+class PythonStreamingDataSourceWriteSuite extends PythonDataSourceSuiteBase {
+
+  import testImplicits._
+
+  val waitTimeout = 15.seconds
+
+  protected def simpleDataStreamWriterScript: String =
+    s"""
+       |import json
+       |import uuid
+       |import os
+       |from pyspark import TaskContext
+       |from pyspark.sql.datasource import DataSource, DataSourceStreamWriter
+       |from pyspark.sql.datasource import WriterCommitMessage
+       |
+       |class SimpleDataSourceStreamWriter(DataSourceStreamWriter):
+       |    def __init__(self, options, overwrite):
+       |        self.options = options
+       |        self.overwrite = overwrite
+       |
+       |    def write(self, iterator):
+       |        context = TaskContext.get()
+       |        partition_id = context.partitionId()
+       |        path = self.options.get("path")
+       |        assert path is not None
+       |        output_path = os.path.join(path, f"{partition_id}.json")
+       |        cnt = 0
+       |        mode = "w" if self.overwrite else "a"
+       |        with open(output_path, mode) as file:
+       |            for row in iterator:
+       |                file.write(json.dumps(row.asDict()) + "\\n")
+       |        return WriterCommitMessage()
+       |
+       |class SimpleDataSource(DataSource):
+       |    def schema(self) -> str:
+       |        return "id INT"
+       |    def streamWriter(self, schema, overwrite):
+       |        return SimpleDataSourceStreamWriter(self.options, overwrite)
+       |""".stripMargin
 
   Seq("append", "complete").foreach { mode =>
     test(s"data source stream write - $mode mode") {