diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index be13cbc51a9d3..0c28fd64f7956 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
@@ -69,7 +69,8 @@ object FileFormatWriter extends Logging {
       val bucketSpec: Option[BucketSpec],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
-      val maxRecordsPerFile: Long)
+      val maxRecordsPerFile: Long,
+      val outputOrdering: Seq[SortOrder])
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
@@ -126,7 +127,8 @@ object FileFormatWriter extends Logging {
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
-        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
+      outputOrdering = queryExecution.executedPlan.outputOrdering
     )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
@@ -369,7 +371,78 @@ object FileFormatWriter extends Logging {
         context = taskAttemptContext)
     }
 
+    // Returns the partition path given a partition key.
+    private val getPartitionStringFunc = UnsafeProjection.create(
+      Seq(Concat(partitionStringExpression)), description.partitionColumns)
+
+    // Returns the data columns to be written given an input row
+    private val getOutputRow = UnsafeProjection.create(
+      description.dataColumns, description.allColumns)
+
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
+      val outputOrderingExprs = description.outputOrdering.map(_.child)
+      val sortedByPartitionCols =
+        if (description.partitionColumns.length > outputOrderingExprs.length) {
+          false
+        } else {
+          description.partitionColumns.zip(outputOrderingExprs).forall {
+            case (partitionCol, outputOrderExpr) => partitionCol.semanticEquals(outputOrderExpr)
+          }
+        }
+
+      if (sortedByPartitionCols && bucketIdExpression.isEmpty) {
+        // If the input data is sorted by partition columns and no bucketing is specified,
+        // we don't need to sort the data by partition columns anymore.
+
+        val getPartitioningKey = UnsafeProjection.create(
+          description.partitionColumns, description.allColumns)
+
+        // If anything below fails, we should abort the task.
+        var recordsInFile: Long = 0L
+        var fileCounter = 0
+        var currentKey: UnsafeRow = null
+        val updatedPartitions = mutable.Set[String]()
+        while (iter.hasNext) {
+          val currentRow = iter.next()
+          val nextKey = getPartitioningKey(currentRow).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            // See a new key - write to a new partition (new file).
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
+            recordsInFile = 0
+            fileCounter = 0
+
+            releaseResources()
+            newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+            val partitionPath = getPartitionStringFunc(currentKey).getString(0)
+            if (partitionPath.nonEmpty) {
+              updatedPartitions.add(partitionPath)
+            }
+          } else if (description.maxRecordsPerFile > 0 &&
+              recordsInFile >= description.maxRecordsPerFile) {
+            // Exceeded the threshold in terms of the number of records per file.
+            // Create a new file by increasing the file counter.
+            recordsInFile = 0
+            fileCounter += 1
+            assert(fileCounter < MAX_FILE_COUNTER,
+              s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+            releaseResources()
+            newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          }
+
+          currentWriter.write(getOutputRow(currentRow))
+          recordsInFile += 1
+        }
+        releaseResources()
+        updatedPartitions.toSet
+      } else {
+        executeWithSort(iter)
+      }
+    }
+
+    private def executeWithSort(iter: Iterator[InternalRow]): Set[String] = {
       // We should first sort by partition columns, then bucket id, and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
         description.partitionColumns ++ bucketIdExpression ++ sortColumns
@@ -381,14 +454,6 @@ object FileFormatWriter extends Logging {
         case _ => StructField("bucketId", IntegerType, nullable = false)
       })
 
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
-
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
       // Sorts the data before write, so that we only need one writer at the same time.
       val sorter = new UnsafeKVExternalSorter(
         sortingKeySchema,
@@ -404,6 +469,8 @@ object FileFormatWriter extends Logging {
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
 
+      val sortedIterator = sorter.sortedIterator()
+
       val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
         identity
       } else {
@@ -412,8 +479,6 @@ object FileFormatWriter extends Logging {
         })
       }
 
-      val sortedIterator = sorter.sortedIterator()
-
       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index f36162858bf7a..c76c38c07160d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -487,6 +487,36 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
   }
 
+  test("SPARK-19352: Keep sort order of rows after external sorter when writing") {
+    spark.stop()
+    // Explicitly set memory configuration to force `UnsafeKVExternalSorter` to spill to files
+    // when inserting data.
+    val newSpark = SparkSession.builder()
+      .master("local")
+      .appName("test")
+      .config("spark.buffer.pageSize", "16b")
+      .config("spark.testing.memory", "1400")
+      .config("spark.memory.fraction", "0.1")
+      .config("spark.shuffle.sort.initialBufferSize", "2")
+      .config("spark.memory.offHeap.enabled", "false")
+      .getOrCreate()
+    withTempPath { path =>
+      val tempDir = path.getCanonicalPath
+      val df = newSpark.range(100)
+        .select($"id", explode(array(col("id") + 1, col("id") + 2, col("id") + 3)).as("value"))
+        .repartition($"id")
+        .sortWithinPartitions($"id", $"value".desc).toDF()
+
+      df.write
+        .partitionBy("id")
+        .parquet(tempDir)
+
+      val dfReadIn = newSpark.read.parquet(tempDir).select("id", "value")
+      checkAnswer(df.filter("id = 65"), dfReadIn.filter("id = 65"))
+    }
+    newSpark.stop()
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =
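
Note on the fast path introduced above: the external sort in `executeWithSort` can be skipped only when the plan's `outputOrdering` has the partition columns as a prefix (and no bucketing is requested), because rows for each partition then arrive contiguously and whatever within-partition order the user asked for is preserved as written. Below is a minimal, self-contained sketch of that prefix check, with plain strings standing in for Catalyst expressions (`==` in place of `semanticEquals`); the object and method names are illustrative only, not part of Spark's API.

object SortedByPartitionColsCheck {
  // True when `outputOrdering` begins with all of `partitionCols`, in order.
  def sortedByPartitionCols(partitionCols: Seq[String], outputOrdering: Seq[String]): Boolean =
    partitionCols.length <= outputOrdering.length &&
      partitionCols.zip(outputOrdering).forall { case (p, o) => p == o }

  def main(args: Array[String]): Unit = {
    // repartition($"id").sortWithinPartitions($"id", $"value".desc) yields an
    // output ordering of ("id", "value"): "id" is a prefix, so the writer can
    // stream rows directly and the descending "value" order survives on disk.
    assert(sortedByPartitionCols(Seq("id"), Seq("id", "value")))
    // Ordered by "value" alone: "id" is not a prefix, fall back to executeWithSort.
    assert(!sortedByPartitionCols(Seq("id"), Seq("value")))
    println("prefix check behaves as expected")
  }
}

The explicit length guard mirrors the `partitionColumns.length > outputOrderingExprs.length` test in the patch: `zip` truncates to the shorter sequence, so without it an output ordering shorter than the partition column list could pass the `forall` vacuously.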