
Commit 652dca2

[SPARK-19256][SQL] Move bucketing constraints out of FileFormatWriter into RunnableCommand

1 parent 2250cb7

6 files changed, +69 -53 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala

Lines changed: 6 additions & 1 deletion

@@ -21,8 +21,9 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
 import org.apache.spark.sql.execution.datasources.FileFormatWriter
@@ -60,5 +61,9 @@ trait DataWritingCommand extends Command {
     new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
   }
 
+  def requiredDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution)
+
+  def requiredOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+
   def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
 }
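The two new hooks let a write command declare, one entry per child, how its input should be distributed and ordered; the defaults above mean "no constraint". As a rough illustration (the dt column and its StringType are made up for the example, not part of the commit), a single-child command that wants its rows pre-sorted on a partition column but does not care about distribution would return values like these:

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.types.StringType

// Hypothetical single-child command: no distribution constraint, but rows
// should arrive sorted ascending by the partition column `dt`.
val dt = AttributeReference("dt", StringType)()

val requiredDistribution: Seq[Distribution] = Seq(UnspecifiedDistribution)      // one entry per child
val requiredOrdering: Seq[Seq[SortOrder]] = Seq(Seq(SortOrder(dt, Ascending)))  // sort spec for that child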

sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala

Lines changed: 6 additions & 1 deletion

@@ -23,9 +23,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SortOrder}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.debug._
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -112,6 +113,10 @@ case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
 
   override def nodeName: String = "Execute " + cmd.nodeName
 
+  override def requiredChildDistribution: Seq[Distribution] = cmd.requiredDistribution
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = cmd.requiredOrdering
+
   override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray
 
   override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator
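With these two overrides in place, DataWritingCommandExec takes part in the planner's normal preparation pass: EnsureRequirements compares each child against requiredChildDistribution and requiredChildOrdering and inserts an exchange or sort where needed. A rough, hypothetical sketch of the ordering half (this is not EnsureRequirements' actual code; it simply mirrors the orderingMatched check that this commit removes from FileFormatWriter below):

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.execution.{SortExec, SparkPlan}

// Conceptual sketch: wrap the child in a per-partition sort when its
// outputOrdering does not already start with the required sort expressions.
def ensureOrdering(child: SparkPlan, required: Seq[SortOrder]): SparkPlan = {
  val actual = child.outputOrdering.map(_.child)
  val matched = required.length <= actual.length &&
    required.map(_.child).zip(actual).forall { case (r, a) => r.semanticEquals(a) }
  if (matched) child else SortExec(required, global = false, child = child)
}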

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 10 additions & 47 deletions

@@ -109,7 +109,7 @@ object FileFormatWriter extends Logging {
       outputSpec: OutputSpec,
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
-      bucketSpec: Option[BucketSpec],
+      bucketIdExpression: Option[Expression],
       statsTrackers: Seq[WriteJobStatsTracker],
       options: Map[String, String])
     : Set[String] = {
@@ -122,17 +122,6 @@
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = outputSpec.outputColumns.filterNot(partitionSet.contains)
 
-    val bucketIdExpression = bucketSpec.map { spec =>
-      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-    val sortColumns = bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
-    }
-
     val caseInsensitiveOptions = CaseInsensitiveMap(options)
 
     // Note: prepareWrite has side effect. It sets "job".
@@ -156,40 +145,14 @@
       statsTrackers = statsTrackers
     )
 
-    // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
-    // the sort order doesn't matter
-    val actualOrdering = plan.outputOrdering.map(_.child)
-    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
-      false
-    } else {
-      requiredOrdering.zip(actualOrdering).forall {
-        case (requiredOrder, childOutputOrder) =>
-          requiredOrder.semanticEquals(childOutputOrder)
-      }
-    }
-
     SQLExecution.checkSQLExecutionId(sparkSession)
 
     // This call shouldn't be put into the `try` block below because it only initializes and
     // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
     committer.setupJob(job)
 
     try {
-      val rdd = if (orderingMatched) {
-        plan.execute()
-      } else {
-        // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
-        // the physical plan may have different attribute ids due to optimizer removing some
-        // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-        val orderingExpr = requiredOrdering
-          .map(SortOrder(_, Ascending))
-          .map(BindReferences.bindReference(_, outputSpec.outputColumns))
-        SortExec(
-          orderingExpr,
-          global = false,
-          child = plan).execute()
-      }
+      val rdd = plan.execute()
       val ret = new Array[WriteTaskResult](rdd.partitions.length)
       sparkSession.sparkContext.runJob(
         rdd,
@@ -202,7 +165,7 @@
             committer,
             iterator = iter)
         },
-        0 until rdd.partitions.length,
+        rdd.partitions.indices,
         (index, res: WriteTaskResult) => {
           committer.onTaskCommit(res.commitMsg)
           ret(index) = res
@@ -521,18 +484,18 @@
       var recordsInFile: Long = 0L
       var fileCounter = 0
       val updatedPartitions = mutable.Set[String]()
-      var currentPartionValues: Option[UnsafeRow] = None
+      var currentPartitionValues: Option[UnsafeRow] = None
       var currentBucketId: Option[Int] = None
 
       for (row <- iter) {
         val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None
         val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None
 
-        if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
+        if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) {
           // See a new partition or bucket - write to a new partition dir (or a new bucket file).
-          if (isPartitioned && currentPartionValues != nextPartitionValues) {
-            currentPartionValues = Some(nextPartitionValues.get.copy())
-            statsTrackers.foreach(_.newPartition(currentPartionValues.get))
+          if (isPartitioned && currentPartitionValues != nextPartitionValues) {
+            currentPartitionValues = Some(nextPartitionValues.get.copy())
+            statsTrackers.foreach(_.newPartition(currentPartitionValues.get))
           }
           if (isBucketed) {
             currentBucketId = nextBucketId
@@ -543,7 +506,7 @@
           fileCounter = 0
 
           releaseResources()
-          newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
+          newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
         } else if (desc.maxRecordsPerFile > 0 &&
             recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
@@ -554,7 +517,7 @@
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
 
           releaseResources()
-          newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
+          newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
         }
         val outputRow = getOutputRow(row)
         currentWriter.write(outputRow)
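Besides taking the precomputed bucketIdExpression, write() no longer checks or repairs its input ordering; the command declares the ordering and the planner enforces it before the physical plan reaches this method. The one remaining change in the runJob hunk is a pure style cleanup; a quick self-contained sanity check of the equivalence (a plain array stands in for rdd.partitions, which is not needed to show it):

// `indices` on an array (or any Seq) is just the range 0 until length,
// so runJob is handed exactly the same set of partition ids as before.
val parts = Array("p0", "p1", "p2")               // stand-in for rdd.partitions
assert((0 until parts.length) == parts.indices)   // both are the range 0, 1, 2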

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala

Lines changed: 45 additions & 2 deletions

@@ -25,8 +25,10 @@ import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression,
+  SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -150,6 +152,10 @@ case class InsertIntoHadoopFsRelationCommand(
         }
       }
 
+      val partitionSet = AttributeSet(partitionColumns)
+      val dataColumns = query.output.filterNot(partitionSet.contains)
+      val bucketIdExpression = getBucketIdExpression(dataColumns)
+
       val updatedPartitionPaths =
         FileFormatWriter.write(
           sparkSession = sparkSession,
@@ -160,7 +166,7 @@ case class InsertIntoHadoopFsRelationCommand(
             qualifiedOutputPath.toString, customPartitionLocations, outputColumns),
           hadoopConf = hadoopConf,
           partitionColumns = partitionColumns,
-          bucketSpec = bucketSpec,
+          bucketIdExpression = bucketIdExpression,
          statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
          options = options)
 
@@ -184,6 +190,43 @@ case class InsertIntoHadoopFsRelationCommand(
     Seq.empty[Row]
   }
 
+  private def getBucketIdExpression(dataColumns: Seq[Attribute]): Option[Expression] = {
+    bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is same between shuffle and bucketed data source, which
+      // enables us to only shuffle one side when join a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+  }
+
+  /**
+   * How is `requiredOrdering` determined?
+   *
+   * table type       | requiredOrdering
+   * -----------------+-------------------------------------------------
+   * normal table     | partition columns
+   * bucketed table   | (partition columns + bucketId + sort columns)
+   * -----------------+-------------------------------------------------
+   */
+  override def requiredOrdering: Seq[Seq[SortOrder]] = {
+    val sortExpressions = bucketSpec match {
+      case Some(spec) =>
+        val partitionSet = AttributeSet(partitionColumns)
+        val dataColumns = query.output.filterNot(partitionSet.contains)
+        val bucketIdExpression = getBucketIdExpression(dataColumns)
+        val sortColumns = spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+        partitionColumns ++ bucketIdExpression ++ sortColumns
+
+      case _ => partitionColumns
+    }
+    if (sortExpressions.nonEmpty) {
+      Seq(sortExpressions.map(SortOrder(_, Ascending)))
+    } else {
+      Seq.fill(children.size)(Nil)
+    }
+  }
+
   /**
    * Deletes all partition files that match the specified static prefix. Partitions with custom
    * locations are also cleared based on the custom locations map given to this class.
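The bucket id expression itself is unchanged by this commit, only relocated: it is still HashPartitioning.partitionIdExpression over the bucket columns, i.e. the same per-row expression a hash shuffle would evaluate to pick a reducer partition, which is what keeps bucketed files aligned with shuffle output. A small self-contained illustration, using a made-up userId column and a bucket count of 8:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

// Hypothetical bucket column and bucket count.
val userId = AttributeReference("userId", IntegerType)()

// Roughly pmod(murmur3hash(userId), 8): bucket N of the table then holds
// the same keys as shuffle partition N of a hash shuffle on userId.
val bucketIdExpr = HashPartitioning(Seq(userId), 8).partitionIdExpression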

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala

Lines changed: 1 addition & 1 deletion

@@ -128,7 +128,7 @@ class FileStreamSink(
         outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output),
         hadoopConf = hadoopConf,
         partitionColumns = partitionColumns,
-        bucketSpec = None,
+        bucketIdExpression = None,
         statsTrackers = Nil,
         options = options)
     }

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand {
        FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, allColumns),
      hadoopConf = hadoopConf,
      partitionColumns = partitionAttributes,
-      bucketSpec = None,
+      bucketIdExpression = None,
      statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
      options = Map.empty)
   }
