diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
index 8e29e38b2a64..966dcc6f5e3c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
@@ -40,7 +40,7 @@ private[kafka010] class KafkaBatchWrite(
 
   validateQuery(schema.toAttributes, producerParams, topic)
 
-  override def createBatchWriterFactory(): KafkaBatchWriterFactory =
+  override def createBatchWriterFactory(numPartitions: Int): KafkaBatchWriterFactory =
     KafkaBatchWriterFactory(topic, producerParams, schema)
 
   override def commit(messages: Array[WriterCommitMessage]): Unit = {}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
index 2b50b771e694..cfef9674c9b1 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
@@ -41,7 +41,7 @@ private[kafka010] class KafkaStreamingWrite(
 
   validateQuery(schema.toAttributes, producerParams, topic)
 
-  override def createStreamingWriterFactory(): KafkaStreamWriterFactory =
+  override def createStreamingWriterFactory(numPartitions: Int): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
index 37c5539d2518..911d06a9474f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java
@@ -45,8 +45,10 @@ public interface BatchWrite {
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
+   *
+   * @param numPartitions The number of partitions of the RDD that is going to be written.
    */
-  DataWriterFactory createBatchWriterFactory();
+  DataWriterFactory createBatchWriterFactory(int numPartitions);
 
   /**
    * Returns whether Spark should use the commit coordinator to ensure that at most one task for
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
index 0821b3489165..9770420f00c4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java
@@ -48,8 +48,10 @@ public interface StreamingWrite {
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
+   *
+   * @param numPartitions The number of partitions of the RDD that is going to be written.
    */
-  StreamingDataWriterFactory createStreamingWriterFactory();
+  StreamingDataWriterFactory createStreamingWriterFactory(int numPartitions);
 
   /**
    * Commits this writing job for the specified epoch with a list of commit messages. The commit
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
index 414f9d583486..4f732a59a30f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala
@@ -98,6 +98,8 @@ class InMemoryTable(
 
     new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite {
       private var writer: BatchWrite = Append
+      private var queryIdProvided = false
+      private var inputDataSchemaProvided = false
 
       override def truncate(): WriteBuilder = {
         assert(writer == Append)
@@ -117,12 +119,32 @@ class InMemoryTable(
         this
       }
 
-      override def buildForBatch(): BatchWrite = writer
+      override def withQueryId(queryId: String): WriteBuilder = {
+        assert(!queryIdProvided, "queryId provided twice")
+        queryIdProvided = true
+        this
+      }
+
+      override def withInputDataSchema(schema: StructType): WriteBuilder = {
+        assert(!inputDataSchemaProvided, "schema provided twice")
+        inputDataSchemaProvided = true
+        this
+      }
+
+      override def buildForBatch(): BatchWrite = {
+        assert(
+          inputDataSchemaProvided,
+          "Input data schema wasn't provided before calling buildForBatch")
+        assert(
+          queryIdProvided,
+          "Query id wasn't provided before calling buildForBatch")
+        writer
+      }
     }
   }
 
   private abstract class TestBatchWrite extends BatchWrite {
-    override def createBatchWriterFactory(): DataWriterFactory = {
+    override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory = {
       BufferedRowsWriterFactory
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
index 3f4f29c3e135..1d682c482f82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
@@ -58,7 +58,7 @@ private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate
 }
 
 private[noop] object NoopBatchWrite extends BatchWrite {
-  override def createBatchWriterFactory(): DataWriterFactory = NoopWriterFactory
+  override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory = NoopWriterFactory
   override def commit(messages: Array[WriterCommitMessage]): Unit = {}
   override def abort(messages: Array[WriterCommitMessage]): Unit = {}
 }
@@ -74,7 +74,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] {
 }
 
 private[noop] object NoopStreamingWrite extends StreamingWrite {
-  override def createStreamingWriterFactory(): StreamingDataWriterFactory =
+  override def createStreamingWriterFactory(numPartitions: Int): StreamingDataWriterFactory =
     NoopStreamingDataWriterFactory
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala
index e7d9a247533c..b24764c986a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala
@@ -44,7 +44,7 @@ class FileBatchWrite(
     committer.abortJob(job)
   }
 
-  override def createBatchWriterFactory(): DataWriterFactory = {
+  override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory = {
     FileWriterFactory(description, committer)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 9f4392da6ab4..2ac9aaa1d7f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -353,8 +353,8 @@ trait V2TableWriteExec extends UnaryExecNode {
   override def output: Seq[Attribute] = Nil
 
   protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = {
-    val writerFactory = batchWrite.createBatchWriterFactory()
     val useCommitCoordinator = batchWrite.useCommitCoordinator
+
     val rdd = query.execute()
     // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
     // partition rdd to make sure we at least set up one write task to write the metadata.
@@ -365,6 +365,8 @@ trait V2TableWriteExec extends UnaryExecNode {
     }
     val messages = new Array[WriterCommitMessage](rddWithNonEmptyPartitions.partitions.length)
     val totalNumRowsAccumulator = new LongAccumulator()
+    val writerFactory = batchWrite.createBatchWriterFactory(
+      rddWithNonEmptyPartitions.partitions.length)
 
     logInfo(s"Start processing data source write support: $batchWrite. " +
       s"The input RDD has ${messages.length} partitions.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index d4e522562e91..7de37868e057 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -38,8 +38,9 @@ case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPl
   override def output: Seq[Attribute] = Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val writerFactory = write.createStreamingWriterFactory()
-    val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
+    val queryRdd = query.execute()
+    val writerFactory = write.createStreamingWriterFactory(queryRdd.partitions.length)
+    val rdd = new ContinuousWriteRDD(queryRdd, writerFactory)
 
     logInfo(s"Start processing data source write support: $write. " +
       s"The input RDD has ${rdd.partitions.length} partitions.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
index 6afb811a4d99..8b9a2964c046 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
@@ -38,7 +38,8 @@ class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap)
   assert(SparkSession.getActiveSession.isDefined)
   protected val spark = SparkSession.getActiveSession.get
 
-  def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory
+  def createStreamingWriterFactory(numPartitions: Int): StreamingDataWriterFactory =
+    PackedRowWriterFactory
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index bae7fa7d0735..bcda81229f23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -72,7 +72,8 @@ case class ForeachWriterTable[T](
       override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
       override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
 
-      override def createStreamingWriterFactory(): StreamingDataWriterFactory = {
+      override def createStreamingWriterFactory(
+          numPartitions: Int): StreamingDataWriterFactory = {
         val rowConverter: InternalRow => T = converter match {
           case Left(enc) =>
             val boundEnc = enc.resolveAndBind(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
index 5f12832cd255..caca1a6e71a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
@@ -36,8 +36,8 @@ class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWrite) extends B
     writeSupport.abort(eppchId, messages)
   }
 
-  override def createBatchWriterFactory(): DataWriterFactory = {
-    new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory())
+  override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory = {
+    new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory(numPartitions))
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index 51ab5ce3578a..057826f08260 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -140,7 +140,7 @@ class MemoryStreamingWrite(
     val sink: MemorySink, schema: StructType, needTruncate: Boolean)
   extends StreamingWrite {
 
-  override def createStreamingWriterFactory: MemoryWriterFactory = {
+  override def createStreamingWriterFactory(numPartitions: Int): MemoryWriterFactory = {
     MemoryWriterFactory(schema)
   }
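
For illustration only (not part of this patch): a minimal sketch of how a DataSource V2 connector might use the new `numPartitions` argument, assuming hypothetical `ShardedBatchWrite`, `ShardedWriterFactory`, and `ShardedWriter` classes. The factory captures the total partition count so each write task can, for example, name its output shard `part-<partitionId>-of-<numPartitions>`.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical example classes; the names are not part of Spark.
class ShardedBatchWrite extends BatchWrite {
  // numPartitions is supplied by Spark with this change; capture it in the factory,
  // which is serialized and shipped to every write task on the executors.
  override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory =
    ShardedWriterFactory(numPartitions)

  override def commit(messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}

case class ShardedWriterFactory(numPartitions: Int) extends DataWriterFactory {
  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] =
    new ShardedWriter(s"part-$partitionId-of-$numPartitions")
}

class ShardedWriter(shardName: String) extends DataWriter[InternalRow] {
  override def write(record: InternalRow): Unit = {
    // Write the row to the shard identified by shardName (omitted in this sketch).
  }
  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = {}
  // No-op close; declared without `override` so the sketch compiles whether or not
  // DataWriter extends Closeable in the Spark version at hand.
  def close(): Unit = {}
}
```

Because the factory is constructed on the driver and serialized to executors, capturing the partition count in the factory is enough to make it visible inside every writer.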