From 3c040b664b7aeb0d1ee78272f79140a34ec30ef6 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 28 Jan 2017 00:45:11 +0000
Subject: [PATCH 1/3] Keep sort order of rows after external sorter when writing.

---
 .../sql/execution/UnsafeKVExternalSorter.java |  33 +++++++-
 .../datasources/FileFormatWriter.scala        |  81 +++++++++++++++----
 .../UnsafeKVExternalSorterSuite.scala         |  52 +++++++++++-
 .../datasources/FileSourceStrategySuite.scala |  30 +++++++
 4 files changed, 172 insertions(+), 24 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index ee5bcfd02c79e..50cac27c5c43c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -19,6 +19,10 @@
 import javax.annotation.Nullable;
 import java.io.IOException;
+import java.util.List;
+
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
 
 import com.google.common.annotations.VisibleForTesting;
 
@@ -26,6 +30,7 @@
 import org.apache.spark.TaskContext;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializerManager;
+import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -58,7 +63,7 @@ public UnsafeKVExternalSorter(
       long pageSizeBytes,
       long numElementsForSpillThreshold) throws IOException {
     this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
-      numElementsForSpillThreshold, null);
+      numElementsForSpillThreshold, null, null);
   }
 
   public UnsafeKVExternalSorter(
@@ -69,14 +74,34 @@ public UnsafeKVExternalSorter(
       long pageSizeBytes,
       long numElementsForSpillThreshold,
       @Nullable BytesToBytesMap map) throws IOException {
+    this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
+      numElementsForSpillThreshold, map, null);
+  }
+
+  public UnsafeKVExternalSorter(
+      StructType keySchema,
+      StructType valueSchema,
+      BlockManager blockManager,
+      SerializerManager serializerManager,
+      long pageSizeBytes,
+      long numElementsForSpillThreshold,
+      @Nullable BytesToBytesMap map,
+      @Nullable List<SortOrder> ordering) throws IOException {
     this.keySchema = keySchema;
     this.valueSchema = valueSchema;
     final TaskContext taskContext = TaskContext.get();
 
     prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
     PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
-    BaseOrdering ordering = GenerateOrdering.create(keySchema);
-    KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
+    KVComparator recordComparator = null;
+    if (ordering == null) {
+      recordComparator = new KVComparator(GenerateOrdering.create(keySchema), keySchema.length());
+    } else {
+      Seq<SortOrder> orderingSeq =
+        JavaConverters.collectionAsScalaIterableConverter(ordering).asScala().toSeq();
+      recordComparator = new KVComparator((BaseOrdering)GenerateOrdering.generate(orderingSeq),
+        ordering.size());
+    }
     boolean canUseRadixSort = keySchema.length() == 1 &&
       SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0));
 
@@ -137,7 +162,7 @@ public UnsafeKVExternalSorter(
           blockManager,
           serializerManager,
          taskContext,
-          new KVComparator(ordering, keySchema.length()),
+          recordComparator,
          prefixComparator,
          SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize",
            UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE),

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 16c5193eda8df..8a2db6d4d0c54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.{Date, UUID}
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configuration
@@ -68,7 +69,8 @@ object FileFormatWriter extends Logging {
       val bucketSpec: Option[BucketSpec],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
-      val maxRecordsPerFile: Long)
+      val maxRecordsPerFile: Long,
+      val orderingInPartition: Seq[SortOrder])
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
@@ -125,7 +127,8 @@ object FileFormatWriter extends Logging {
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
-        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
+      orderingInPartition = queryExecution.executedPlan.outputOrdering
     )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
@@ -368,17 +371,58 @@ object FileFormatWriter extends Logging {
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
+      // If there is sort ordering in the data, we need to keep the ordering.
+      val orderingExpressions: Seq[Expression] = if (description.orderingInPartition.isEmpty) {
+        Nil
+      } else {
+        description.orderingInPartition.map(_.child)
+      }
+
+      // We should first sort by partition columns, then bucket id, then sort ordering in the data,
+      // and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
+        description.partitionColumns ++ bucketIdExpression ++ orderingExpressions ++ sortColumns
       val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
 
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
+      val bucketIdExprIndex =
+        sortingExpressions.length - sortColumns.length - orderingExpressions.length - 1
+
+      val sortingKeySchema = StructType(sortingExpressions.zipWithIndex.map { case (e, index) =>
+        e match {
+          case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+          // The sorting expressions are all `Attribute` except bucket id and
+          // sorting order's children expressions.
+          case _ if index == bucketIdExprIndex =>
+            StructField("bucketId", IntegerType, nullable = false)
+          case _ if index > bucketIdExprIndex =>
+            StructField(s"_sortOrder_$index", e.dataType, e.nullable)
+        }
       })
 
+      val beginSortingExpr =
+        sortingExpressions.length - sortColumns.length - orderingExpressions.length
+      val recordSortingOrder =
+        if (description.orderingInPartition.isEmpty) {
+          null
+        } else {
+          sortingExpressions.zipWithIndex.map { case (field, ordinal) =>
+            if (ordinal < beginSortingExpr ||
+                ordinal > beginSortingExpr + orderingExpressions.length) {
+              // For partition column, bucket id and sort by columns, we sort by ascending.
+              SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
+            } else {
+              // For the sort ordering of data, we need to keep its sort direction and
+              // null ordering.
+              val direction =
+                description.orderingInPartition(ordinal - beginSortingExpr).direction
+              val nullOrdering =
+                description.orderingInPartition(ordinal - beginSortingExpr).nullOrdering
+              SortOrder(BoundReference(ordinal, field.dataType, nullable = true),
+                direction, nullOrdering)
+            }
+          }.asJava
+        }
+
       // Returns the data columns to be written given an input row
       val getOutputRow = UnsafeProjection.create(
         description.dataColumns, description.allColumns)
@@ -395,20 +439,25 @@
         SparkEnv.get.serializerManager,
         TaskContext.get().taskMemoryManager().pageSizeBytes,
         SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
+          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+        null,
+        recordSortingOrder)
 
       while (iter.hasNext) {
         val currentRow = iter.next()
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
 
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      val getBucketingKey: InternalRow => InternalRow =
+        if (sortColumns.isEmpty && orderingExpressions.isEmpty) {
+          identity
+        } else {
+          val bucketingKeyExprs =
+            sortingExpressions.dropRight(sortColumns.length + orderingExpressions.length)
+          UnsafeProjection.create(bucketingKeyExprs.zipWithIndex.map {
+            case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+          })
+        }
 
       val sortedIterator = sorter.sortedIterator()
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 3d869c77e9608..06243072f3119 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.execution
 
 import java.util.Properties
 
+import scala.collection.JavaConverters._
 import scala.util.Random
 
 import org.apache.spark._
 import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, InterpretedOrdering, SortOrder,
+  UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
@@ -110,7 +111,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       valueSchema: StructType,
       inputData: Seq[(InternalRow, InternalRow)],
       pageSize: Long,
-      spill: Boolean): Unit = {
+      spill: Boolean,
+      sortOrdering: java.util.List[SortOrder] = null): Unit = {
     val memoryManager =
       new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"))
     val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
@@ -125,7 +127,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     val sorter = new UnsafeKVExternalSorter(
       keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager,
-      pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)
+      pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD,
+      null, sortOrdering)
 
     // Insert the keys and values into the sorter
     inputData.foreach { case (k, v) =>
@@ -145,7 +148,11 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     }
     sorter.cleanupResources()
 
-    val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType))
+    val keyOrdering = if (sortOrdering == null) {
+      InterpretedOrdering.forSchema(keySchema.map(_.dataType))
+    } else {
+      new InterpretedOrdering(sortOrdering.asScala)
+    }
     val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType))
     val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
       override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
@@ -204,4 +211,41 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       spill = true
     )
   }
+
+  test("kv sorting with records that exceed page size: with specified sort order") {
+    val pageSize = 128
+
+    val keySchema = StructType(StructField("a", BinaryType) :: StructField("b", BinaryType) :: Nil)
+    val valueSchema = StructType(StructField("c", BinaryType) :: Nil)
+    val keyExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
+    val valueExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
+    val keyConverter = UnsafeProjection.create(keySchema)
+    val valueConverter = UnsafeProjection.create(valueSchema)
+
+    val rand = new Random()
+    val inputData = Seq.fill(1024) {
+      val kBytes1 = new Array[Byte](rand.nextInt(pageSize))
+      val kBytes2 = new Array[Byte](rand.nextInt(pageSize))
+      val vBytes = new Array[Byte](rand.nextInt(pageSize))
+      rand.nextBytes(kBytes1)
+      rand.nextBytes(kBytes2)
+      rand.nextBytes(vBytes)
+      val k =
+        keyConverter(keyExternalConverter.apply(Row(kBytes1, kBytes2)).asInstanceOf[InternalRow])
+      val v = valueConverter(valueExternalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow])
+      (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
+    }
+
+    val sortOrder = SortOrder(BoundReference(0, BinaryType, nullable = true), Ascending) ::
+      SortOrder(BoundReference(1, BinaryType, nullable = true), Descending) :: Nil
+
+    testKVSorter(
+      keySchema,
+      valueSchema,
+      inputData,
+      pageSize,
+      spill = true,
+      sortOrder.asJava
+    )
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index f36162858bf7a..3756b7ef9d85f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -487,6 +487,36 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
   }
 
+  test("SPARK-19352: Keep sort order of rows after external sorter when writing") {
+    spark.stop()
+    // Explicitly set memory configuration to force `UnsafeKVExternalSorter` to spill to files
+    // when inserting data.
+    val newSpark = SparkSession.builder()
+      .master("local")
+      .appName("test")
+      .config("spark.buffer.pageSize", "16b")
+      .config("spark.testing.memory", "1400")
+      .config("spark.memory.fraction", "0.1")
+      .config("spark.shuffle.sort.initialBufferSize", "2")
+      .config("spark.memory.offHeap.enabled", "false")
+      .getOrCreate()
+    withTempPath { path =>
+      val tempDir = path.getCanonicalPath
+      val df = newSpark.range(100)
+        .select($"id", explode(array(col("id") + 1, col("id") + 2, col("id") + 3)).as("value"))
+        .repartition($"id")
+        .sortWithinPartitions($"value".desc).toDF()
+
+      df.write
+        .partitionBy("id")
+        .parquet(tempDir)
+
+      val dfReadIn = newSpark.read.parquet(tempDir).select("id", "value")
+      checkAnswer(df.filter("id = 65"), dfReadIn.filter("id = 65"))
+    }
+    newSpark.stop()
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =

From d9a067cf3549af59e515b617c2d930ceb0cb129b Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 10 Feb 2017 04:32:29 +0000
Subject: [PATCH 2/3] import.

---
 .../sql/execution/UnsafeKVExternalSorter.java |  33 ++--------
 .../UnsafeKVExternalSorterSuite.scala         |  52 ++-----------------
 2 files changed, 8 insertions(+), 77 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 50cac27c5c43c..ee5bcfd02c79e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -19,10 +19,6 @@
 import javax.annotation.Nullable;
 import java.io.IOException;
-import java.util.List;
-
-import scala.collection.JavaConverters;
-import scala.collection.Seq;
 
 import com.google.common.annotations.VisibleForTesting;
 
@@ -30,7 +26,6 @@
 import org.apache.spark.TaskContext;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializerManager;
-import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -63,7 +58,7 @@ public UnsafeKVExternalSorter(
       long pageSizeBytes,
       long numElementsForSpillThreshold) throws IOException {
     this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
-      numElementsForSpillThreshold, null, null);
+      numElementsForSpillThreshold, null);
   }
 
   public UnsafeKVExternalSorter(
@@ -74,34 +69,14 @@ public UnsafeKVExternalSorter(
       long pageSizeBytes,
       long numElementsForSpillThreshold,
       @Nullable BytesToBytesMap map) throws IOException {
-    this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
-      numElementsForSpillThreshold, map, null);
-  }
-
-  public UnsafeKVExternalSorter(
-      StructType keySchema,
-      StructType valueSchema,
-      BlockManager blockManager,
-      SerializerManager serializerManager,
-      long pageSizeBytes,
-      long numElementsForSpillThreshold,
-      @Nullable BytesToBytesMap map,
-      @Nullable List<SortOrder> ordering) throws IOException {
     this.keySchema = keySchema;
     this.valueSchema = valueSchema;
     final TaskContext taskContext = TaskContext.get();
 
     prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
     PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
-    KVComparator recordComparator = null;
-    if (ordering == null) {
-      recordComparator = new KVComparator(GenerateOrdering.create(keySchema), keySchema.length());
-    } else {
-      Seq<SortOrder> orderingSeq =
-        JavaConverters.collectionAsScalaIterableConverter(ordering).asScala().toSeq();
-      recordComparator = new KVComparator((BaseOrdering)GenerateOrdering.generate(orderingSeq),
-        ordering.size());
-    }
+    BaseOrdering ordering = GenerateOrdering.create(keySchema);
+    KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
     boolean canUseRadixSort = keySchema.length() == 1 &&
       SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0));
 
@@ -162,7 +137,7 @@ public UnsafeKVExternalSorter(
           blockManager,
           serializerManager,
          taskContext,
-          recordComparator,
+          new KVComparator(ordering, keySchema.length()),
          prefixComparator,
          SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize",
            UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 06243072f3119..3d869c77e9608 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.execution
 
 import java.util.Properties
 
-import scala.collection.JavaConverters._
 import scala.util.Random
 
 import org.apache.spark._
 import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, InterpretedOrdering, SortOrder,
-  UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
@@ -111,8 +110,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       valueSchema: StructType,
       inputData: Seq[(InternalRow, InternalRow)],
       pageSize: Long,
-      spill: Boolean,
-      sortOrdering: java.util.List[SortOrder] = null): Unit = {
+      spill: Boolean): Unit = {
     val memoryManager =
       new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"))
     val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
@@ -127,8 +125,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     val sorter = new UnsafeKVExternalSorter(
       keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager,
-      pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD,
-      null, sortOrdering)
+      pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)
 
     // Insert the keys and values into the sorter
     inputData.foreach { case (k, v) =>
@@ -148,11 +145,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
     }
     sorter.cleanupResources()
 
-    val keyOrdering = if (sortOrdering == null) {
-      InterpretedOrdering.forSchema(keySchema.map(_.dataType))
-    } else {
-      new InterpretedOrdering(sortOrdering.asScala)
-    }
+    val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType))
     val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType))
     val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
       override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
@@ -211,41 +204,4 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
       spill = true
     )
   }
-
-  test("kv sorting with records that exceed page size: with specified sort order") {
-    val pageSize = 128
-
-    val keySchema = StructType(StructField("a", BinaryType) :: StructField("b", BinaryType) :: Nil)
-    val valueSchema = StructType(StructField("c", BinaryType) :: Nil)
-    val keyExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
-    val valueExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
-    val keyConverter = UnsafeProjection.create(keySchema)
-    val valueConverter = UnsafeProjection.create(valueSchema)
-
-    val rand = new Random()
-    val inputData = Seq.fill(1024) {
-      val kBytes1 = new Array[Byte](rand.nextInt(pageSize))
-      val kBytes2 = new Array[Byte](rand.nextInt(pageSize))
-      val vBytes = new Array[Byte](rand.nextInt(pageSize))
-      rand.nextBytes(kBytes1)
-      rand.nextBytes(kBytes2)
-      rand.nextBytes(vBytes)
-      val k =
-        keyConverter(keyExternalConverter.apply(Row(kBytes1, kBytes2)).asInstanceOf[InternalRow])
-      val v = valueConverter(valueExternalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow])
-      (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
-    }
-
-    val sortOrder = SortOrder(BoundReference(0, BinaryType, nullable = true), Ascending) ::
-      SortOrder(BoundReference(1, BinaryType, nullable = true), Descending) :: Nil
-
-    testKVSorter(
-      keySchema,
-      valueSchema,
-      inputData,
-      pageSize,
-      spill = true,
-      sortOrder.asJava
-    )
-  }
 }

From b1ce0308cf44ca5bad60a4e954f6169a3c80967e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sun, 12 Feb 2017 03:32:58 +0000
Subject: [PATCH 3/3] Optimize the case when data is sorted by partition columns.
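
When the rows reaching the writer are already sorted by the partition columns
and no bucketing is involved, the extra pass through `UnsafeKVExternalSorter`
is unnecessary: the writer can stream rows straight to per-partition files and
open a new file whenever the partition key changes. A minimal sketch of the
prefix check this relies on (names follow the `Description` fields and
`executedPlan.outputOrdering` used in this series; treat it as an illustration,
not the exact hunk below):

    import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}

    // True when the plan's output ordering starts with the partition columns,
    // i.e. rows belonging to the same partition arrive contiguously.
    def sortedByPartitionColumns(
        partitionColumns: Seq[Attribute],
        outputOrdering: Seq[SortOrder]): Boolean = {
      val orderingExprs = outputOrdering.map(_.child)
      partitionColumns.length <= orderingExprs.length &&
        partitionColumns.zip(orderingExprs).forall {
          case (partCol, orderExpr) => partCol.semanticEquals(orderExpr)
        }
    }

If the check fails, the writer falls back to the sort-based path as before.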
---
 .../datasources/FileFormatWriter.scala        | 156 ++++++++++--------
 .../datasources/FileSourceStrategySuite.scala |   2 +-
 2 files changed, 87 insertions(+), 71 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 4f4fb94ec882c..0c28fd64f7956 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.{Date, UUID}
 
-import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configuration
@@ -36,7 +35,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
@@ -71,7 +70,7 @@ object FileFormatWriter extends Logging {
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long,
-      val orderingInPartition: Seq[SortOrder])
+      val outputOrdering: Seq[SortOrder])
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
@@ -129,7 +128,7 @@ object FileFormatWriter extends Logging {
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
-      orderingInPartition = queryExecution.executedPlan.outputOrdering
+      outputOrdering = queryExecution.executedPlan.outputOrdering
     )
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
@@ -372,66 +371,88 @@ object FileFormatWriter extends Logging {
         context = taskAttemptContext)
     }
 
-    override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // If there is sort ordering in the data, we need to keep the ordering.
-      val orderingExpressions: Seq[Expression] = if (description.orderingInPartition.isEmpty) {
-        Nil
-      } else {
-        description.orderingInPartition.map(_.child)
-      }
+    // Returns the partition path given a partition key.
+    private val getPartitionStringFunc = UnsafeProjection.create(
+      Seq(Concat(partitionStringExpression)), description.partitionColumns)
 
-      // We should first sort by partition columns, then bucket id, then sort ordering in the data,
-      // and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ orderingExpressions ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
+    // Returns the data columns to be written given an input row
+    private val getOutputRow = UnsafeProjection.create(
+      description.dataColumns, description.allColumns)
 
-      val bucketIdExprIndex =
-        sortingExpressions.length - sortColumns.length - orderingExpressions.length - 1
-
-      val sortingKeySchema = StructType(sortingExpressions.zipWithIndex.map { case (e, index) =>
-        e match {
-          case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-          // The sorting expressions are all `Attribute` except bucket id and
-          // sorting order's children expressions.
-          case _ if index == bucketIdExprIndex =>
-            StructField("bucketId", IntegerType, nullable = false)
-          case _ if index > bucketIdExprIndex =>
-            StructField(s"_sortOrder_$index", e.dataType, e.nullable)
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
+      val outputOrderingExprs = description.outputOrdering.map(_.child)
+      val sortedByPartitionCols =
+        if (description.partitionColumns.length > outputOrderingExprs.length) {
+          false
+        } else {
+          description.partitionColumns.zip(outputOrderingExprs).forall {
+            case (partitionCol, outputOrderExpr) => partitionCol.semanticEquals(outputOrderExpr)
+          }
         }
-      })
 
-      val beginSortingExpr =
-        sortingExpressions.length - sortColumns.length - orderingExpressions.length
-      val recordSortingOrder =
-        if (description.orderingInPartition.isEmpty) {
-          null
-        } else {
-          sortingExpressions.zipWithIndex.map { case (field, ordinal) =>
-            if (ordinal < beginSortingExpr ||
-                ordinal > beginSortingExpr + orderingExpressions.length) {
-              // For partition column, bucket id and sort by columns, we sort by ascending.
-              SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
-            } else {
-              // For the sort ordering of data, we need to keep its sort direction and
-              // null ordering.
-              val direction =
-                description.orderingInPartition(ordinal - beginSortingExpr).direction
-              val nullOrdering =
-                description.orderingInPartition(ordinal - beginSortingExpr).nullOrdering
-              SortOrder(BoundReference(ordinal, field.dataType, nullable = true),
-                direction, nullOrdering)
+      if (sortedByPartitionCols && bucketIdExpression.isEmpty) {
+        // If the input data is sorted by partition columns and no bucketing is specified,
+        // we don't need to sort the data by partition columns anymore.
+
+        val getPartitioningKey = UnsafeProjection.create(
+          description.partitionColumns, description.allColumns)
+
+        // If anything below fails, we should abort the task.
+        var recordsInFile: Long = 0L
+        var fileCounter = 0
+        var currentKey: UnsafeRow = null
+        val updatedPartitions = mutable.Set[String]()
+        while (iter.hasNext) {
+          val currentRow = iter.next()
+          val nextKey = getPartitioningKey(currentRow).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            // See a new key - write to a new partition (new file).
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
+            recordsInFile = 0
+            fileCounter = 0
+
+            releaseResources()
+            newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+            val partitionPath = getPartitionStringFunc(currentKey).getString(0)
+            if (partitionPath.nonEmpty) {
+              updatedPartitions.add(partitionPath)
+            }
+          } else if (description.maxRecordsPerFile > 0 &&
+              recordsInFile >= description.maxRecordsPerFile) {
+            // Exceeded the threshold in terms of the number of records per file.
+            // Create a new file by increasing the file counter.
+            recordsInFile = 0
+            fileCounter += 1
+            assert(fileCounter < MAX_FILE_COUNTER,
+              s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
+
+            releaseResources()
+            newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          }
+
+          currentWriter.write(getOutputRow(currentRow))
+          recordsInFile += 1
+        }
+        releaseResources()
+        updatedPartitions.toSet
+      } else {
+        executeWithSort(iter)
+      }
+    }
 
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
+    private def executeWithSort(iter: Iterator[InternalRow]): Set[String] = {
+      // We should first sort by partition columns, then bucket id, and finally sorting columns.
+      val sortingExpressions: Seq[Expression] =
+        description.partitionColumns ++ bucketIdExpression ++ sortColumns
+      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
 
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
+      val sortingKeySchema = StructType(sortingExpressions.map {
+        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+        // The sorting expressions are all `Attribute` except bucket id.
+        case _ => StructField("bucketId", IntegerType, nullable = false)
+      })
 
       // Sorts the data before write, so that we only need one writer at the same time.
       val sorter = new UnsafeKVExternalSorter(
@@ -441,28 +462,23 @@ object FileFormatWriter extends Logging {
         SparkEnv.get.serializerManager,
         TaskContext.get().taskMemoryManager().pageSizeBytes,
         SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
-        null,
-        recordSortingOrder)
+          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
 
       while (iter.hasNext) {
         val currentRow = iter.next()
         sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
       }
 
-      val getBucketingKey: InternalRow => InternalRow =
-        if (sortColumns.isEmpty && orderingExpressions.isEmpty) {
-          identity
-        } else {
-          val bucketingKeyExprs =
-            sortingExpressions.dropRight(sortColumns.length + orderingExpressions.length)
-          UnsafeProjection.create(bucketingKeyExprs.zipWithIndex.map {
-            case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-          })
-        }
-
-      val sortedIterator = sorter.sortedIterator()
+      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+        identity
+      } else {
+        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+        })
+      }
+
+      val sortedIterator = sorter.sortedIterator()
 
       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index 3756b7ef9d85f..c76c38c07160d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -505,7 +505,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
       val df = newSpark.range(100)
         .select($"id", explode(array(col("id") + 1, col("id") + 2, col("id") + 3)).as("value"))
         .repartition($"id")
-        .sortWithinPartitions($"value".desc).toDF()
+        .sortWithinPartitions($"id", $"value".desc).toDF()
 
       df.write
        .partitionBy("id")
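
The test change above is the user-visible pattern these patches target: sort
within each task by the partition column first (so the writer's own sort can be
skipped), then by whatever data column should keep its order inside the written
files. A minimal end-to-end sketch of that pattern, assuming an existing
SparkSession named `spark` and a placeholder output path:

    import org.apache.spark.sql.functions.{array, col, explode}

    val df = spark.range(100)
      .select(col("id"), explode(array(col("id") + 1, col("id") + 2, col("id") + 3)).as("value"))
      .repartition(col("id"))
      // Partition column first, then the column whose order should be preserved.
      .sortWithinPartitions(col("id"), col("value").desc)

    df.write
      .partitionBy("id")
      .parquet("/tmp/keep-sort-order-demo")  // placeholder path

Sorting by the partition column first is what lets the optimized path in
PATCH 3/3 stream rows to per-partition files without re-sorting them.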