Commit f3dba5e

moving away from withNumPartitions
1 parent c6b85f9 commit f3dba5e

14 files changed: +36 -56 lines changed

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ private[kafka010] class KafkaBatchWrite(
 
   validateQuery(schema.toAttributes, producerParams, topic)
 
-  override def createBatchWriterFactory(): KafkaBatchWriterFactory =
+  override def createBatchWriterFactory(numPartitions: Int): KafkaBatchWriterFactory =
     KafkaBatchWriterFactory(topic, producerParams, schema)
 
   override def commit(messages: Array[WriterCommitMessage]): Unit = {}

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ private[kafka010] class KafkaStreamingWrite(
 
   validateQuery(schema.toAttributes, producerParams, topic)
 
-  override def createStreamingWriterFactory(): KafkaStreamWriterFactory =
+  override def createStreamingWriterFactory(numPartitions: Int): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/BatchWrite.java

Lines changed: 3 additions & 1 deletion
@@ -45,8 +45,10 @@ public interface BatchWrite {
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
+   *
+   * @param numPartitions The number of partitions of the RDD that is going to be written.
    */
-  DataWriterFactory createBatchWriterFactory();
+  DataWriterFactory createBatchWriterFactory(int numPartitions);
 
   /**
    * Returns whether Spark should use the commit coordinator to ensure that at most one task for
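Note (not part of the commit): a minimal sketch, in Scala, of how a sink-side BatchWrite could use the new parameter. ExampleBatchWrite and the injected DataWriterFactory are hypothetical and only illustrate the changed signature.

import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, WriterCommitMessage}

// Hypothetical sink: it records the partition count handed over by the driver so that
// commit() can sanity-check that one commit message arrived per input partition.
class ExampleBatchWrite(factory: DataWriterFactory) extends BatchWrite {
  @volatile private var expectedTasks: Int = -1

  override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory = {
    // numPartitions is the partition count of the RDD Spark is about to write,
    // previously only reachable through WriteBuilder.withNumPartitions.
    expectedTasks = numPartitions
    factory
  }

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    require(messages.length == expectedTasks,
      s"expected $expectedTasks commit messages, got ${messages.length}")
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = ()
}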

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/WriteBuilder.java

Lines changed: 0 additions & 10 deletions
@@ -55,16 +55,6 @@ default WriteBuilder withInputDataSchema(StructType schema) {
     return this;
   }
 
-  /**
-   * Passes the number of partitions of the input data from Spark to data source.
-   *
-   * @return a new builder with the `schema`. By default it returns `this`, which means the given
-   *         `numPartitions` is ignored. Please override this method to take the `numPartitions`.
-   */
-  default WriteBuilder withNumPartitions(int numPartitions) {
-    return this;
-  }
-
   /**
    * Returns a {@link BatchWrite} to write data to batch source. By default this method throws
    * exception, data sources must overwrite this method to provide an implementation, if the
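Note (illustrative only, not from this commit): a hedged migration sketch for a connector that previously overrode withNumPartitions. The builder no longer sees the partition count at all; it reaches the sink later, as the argument of createBatchWriterFactory on the BatchWrite built here. Names are hypothetical.

import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder}

// Before this change a connector could override on the builder:
//   def withNumPartitions(numPartitions: Int): WriteBuilder = { ...; this }
// After it, the builder only assembles the write; the count arrives via
// BatchWrite.createBatchWriterFactory(numPartitions) on the object returned below.
class ExampleWriteBuilder(newBatchWrite: () => BatchWrite) extends WriteBuilder {
  override def buildForBatch(): BatchWrite = newBatchWrite()
}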

sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java

Lines changed: 3 additions & 1 deletion
@@ -48,8 +48,10 @@ public interface StreamingWrite {
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
+   *
+   * @param numPartitions The number of partitions of the RDD that is going to be written.
    */
-  StreamingDataWriterFactory createStreamingWriterFactory();
+  StreamingDataWriterFactory createStreamingWriterFactory(int numPartitions);
 
   /**
    * Commits this writing job for the specified epoch with a list of commit messages. The commit
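Note (not part of the commit): the streaming side mirrors the batch change; a minimal hypothetical sketch, assuming an existing StreamingDataWriterFactory is injected.

import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}

// Hypothetical streaming sink: Spark now passes the partition count of the RDD it is about
// to write when it asks for the writer factory, instead of via WriteBuilder.withNumPartitions.
class ExampleStreamingWrite(factory: StreamingDataWriterFactory) extends StreamingWrite {
  @volatile private var lastNumPartitions: Int = -1

  override def createStreamingWriterFactory(numPartitions: Int): StreamingDataWriterFactory = {
    lastNumPartitions = numPartitions // e.g. could be used to pre-allocate per-partition state
    factory
  }

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
}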

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 1 addition & 11 deletions
@@ -98,7 +98,6 @@ class InMemoryTable(
 
     new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite {
       private var writer: BatchWrite = Append
-      private var numPartitionsProvided = false
       private var queryIdProvided = false
       private var inputDataSchemaProvided = false
 
@@ -120,12 +119,6 @@
         this
       }
 
-      override def withNumPartitions(numPartitions: Int): WriteBuilder = {
-        assert(!numPartitionsProvided, "numPartitions provided twice")
-        numPartitionsProvided = true
-        this
-      }
-
       override def withQueryId(queryId: String): WriteBuilder = {
         assert(!queryIdProvided, "queryId provided twice")
         queryIdProvided = true
@@ -145,16 +138,13 @@
         assert(
           queryIdProvided,
           "Query id wasn't provided before calling buildForBatch")
-        assert(
-          numPartitionsProvided,
-          "Number of partitions schema wasn't provided before calling buildForBatch")
         writer
       }
     }
   }
 
   private abstract class TestBatchWrite extends BatchWrite {
-    override def createBatchWriterFactory(): DataWriterFactory = {
+    override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory = {
       BufferedRowsWriterFactory
     }
 

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate
 }
 
 private[noop] object NoopBatchWrite extends BatchWrite {
-  override def createBatchWriterFactory(): DataWriterFactory = NoopWriterFactory
+  override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory = NoopWriterFactory
   override def commit(messages: Array[WriterCommitMessage]): Unit = {}
   override def abort(messages: Array[WriterCommitMessage]): Unit = {}
 }
@@ -74,7 +74,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] {
 }
 
 private[noop] object NoopStreamingWrite extends StreamingWrite {
-  override def createStreamingWriterFactory(): StreamingDataWriterFactory =
+  override def createStreamingWriterFactory(numPartitions: Int): StreamingDataWriterFactory =
     NoopStreamingDataWriterFactory
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ class FileBatchWrite(
     committer.abortJob(job)
   }
 
-  override def createBatchWriterFactory(): DataWriterFactory = {
+  override def createBatchWriterFactory(numPartitions: Int): DataWriterFactory = {
     FileWriterFactory(description, committer)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala

Lines changed: 14 additions & 21 deletions
@@ -86,7 +86,6 @@ case class CreateTableAsSelectExec(
       case table: SupportsWrite =>
         val writeBuilder = table.newWriteBuilder(writeOptions)
           .withInputDataSchema(schema)
-          .withNumPartitions(rdd.getNumPartitions)
           .withQueryId(UUID.randomUUID().toString)
 
         writeBuilder match {
@@ -182,7 +181,6 @@ case class ReplaceTableAsSelectExec(
      case table: SupportsWrite =>
        val writeBuilder = table.newWriteBuilder(writeOptions)
          .withInputDataSchema(schema)
-          .withNumPartitions(rdd.getNumPartitions)
          .withQueryId(UUID.randomUUID().toString)
 
        writeBuilder match {
@@ -334,13 +332,11 @@ case class WriteToDataSourceV2Exec(
 trait BatchWriteHelper {
   def table: SupportsWrite
   def query: SparkPlan
-  def rdd: RDD[InternalRow]
   def writeOptions: CaseInsensitiveStringMap
 
   def newWriteBuilder(): WriteBuilder = {
     table.newWriteBuilder(writeOptions)
       .withInputDataSchema(query.schema)
-      .withNumPartitions(rdd.getNumPartitions)
       .withQueryId(UUID.randomUUID().toString)
   }
 }
@@ -351,38 +347,36 @@
 trait V2TableWriteExec extends UnaryExecNode {
   def query: SparkPlan
 
-  lazy val rdd: RDD[InternalRow] = {
-    val tempRdd = query.execute()
-    // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
-    // partition rdd to make sure we at least set up one write task to write the metadata.
-    if (tempRdd.partitions.length == 0) {
-      sparkContext.parallelize(Array.empty[InternalRow], 1)
-    } else {
-      tempRdd
-    }
-  }
-
   var commitProgress: Option[StreamWriterCommitProgress] = None
 
   override def child: SparkPlan = query
   override def output: Seq[Attribute] = Nil
 
   protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = {
     val useCommitCoordinator = batchWrite.useCommitCoordinator
-    val messages = new Array[WriterCommitMessage](rdd.partitions.length)
-    val totalNumRowsAccumulator = new LongAccumulator()
 
-    val writerFactory = batchWrite.createBatchWriterFactory()
+    val rdd = query.execute()
+    // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
+    // partition rdd to make sure we at least set up one write task to write the metadata.
+    val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
+      sparkContext.parallelize(Array.empty[InternalRow], 1)
+    } else {
+      rdd
+    }
+    val messages = new Array[WriterCommitMessage](rddWithNonEmptyPartitions.partitions.length)
+    val totalNumRowsAccumulator = new LongAccumulator()
+    val writerFactory = batchWrite.createBatchWriterFactory(
+      rddWithNonEmptyPartitions.partitions.length)
 
     logInfo(s"Start processing data source write support: $batchWrite. " +
       s"The input RDD has ${messages.length} partitions.")
 
     try {
       sparkContext.runJob(
-        rdd,
+        rddWithNonEmptyPartitions,
         (context: TaskContext, iter: Iterator[InternalRow]) =>
           DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator),
-        rdd.partitions.indices,
+        rddWithNonEmptyPartitions.partitions.indices,
        (index, result: DataWritingSparkTaskResult) => {
          val commitMessage = result.writerCommitMessage
          messages(index) = commitMessage
@@ -488,7 +482,6 @@ private[v2] trait AtomicTableWriteExec extends V2TableWriteExec with SupportsV1W
      case table: SupportsWrite =>
        val writeBuilder = table.newWriteBuilder(writeOptions)
          .withInputDataSchema(query.schema)
-          .withNumPartitions(rdd.getNumPartitions)
          .withQueryId(UUID.randomUUID().toString)
 
        val writtenRows = writeBuilder match {

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala

Lines changed: 3 additions & 2 deletions
@@ -38,8 +38,9 @@ case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPl
   override def output: Seq[Attribute] = Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val writerFactory = write.createStreamingWriterFactory()
-    val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
+    val queryRdd = query.execute()
+    val writerFactory = write.createStreamingWriterFactory(queryRdd.partitions.length)
+    val rdd = new ContinuousWriteRDD(queryRdd, writerFactory)
 
     logInfo(s"Start processing data source write support: $write. " +
       s"The input RDD has ${rdd.partitions.length} partitions.")
