@@ -35,7 +35,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
@@ -69,7 +69,8 @@ object FileFormatWriter extends Logging {
val bucketSpec: Option[BucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
val maxRecordsPerFile: Long,
val outputOrdering: Seq[SortOrder])
extends Serializable {

assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
@@ -126,7 +127,8 @@ object FileFormatWriter extends Logging {
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
outputOrdering = queryExecution.executedPlan.outputOrdering
)

SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
@@ -369,7 +371,78 @@ object FileFormatWriter extends Logging {
context = taskAttemptContext)
}

// Returns the partition path given a partition key.
private val getPartitionStringFunc = UnsafeProjection.create(
Seq(Concat(partitionStringExpression)), description.partitionColumns)

// Returns the data columns to be written given an input row
private val getOutputRow = UnsafeProjection.create(
description.dataColumns, description.allColumns)

override def execute(iter: Iterator[InternalRow]): Set[String] = {
val outputOrderingExprs = description.outputOrdering.map(_.child)
Contributor: This duplicates too much code.

Member Author: Mainly it's because there are two types of iterators: one is [UnsafeRow, UnsafeRow], the other is just [UnsafeRow].

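Not part of the diff: a minimal, Spark-free sketch of the two iterator shapes mentioned above (the names Row, writeSorted and writeFromSorter are illustrative, not from the patch). The pre-sorted path consumes plain rows and must project the partition key per row, while the external-sorter path already yields (key, row) pairs, which is why the two write loops end up looking so similar.

// Illustrative only: two input shapes feeding the same "new key => new file" loop.
object IteratorShapes {
  type Row = Vector[Any]

  // Pre-sorted input: Iterator[Row]; the partition key is projected from each row.
  def writeSorted(rows: Iterator[Row], partitionKey: Row => Row)(write: (Row, Row) => Unit): Unit = {
    var currentKey: Row = null
    rows.foreach { row =>
      val key = partitionKey(row)
      if (key != currentKey) {
        currentKey = key // in the real task: release the old writer, open a new file
      }
      write(currentKey, row)
    }
  }

  // External-sorter output: Iterator[(Row, Row)]; the sorting key arrives with each entry.
  def writeFromSorter(kvs: Iterator[(Row, Row)])(write: (Row, Row) => Unit): Unit = {
    var currentKey: Row = null
    kvs.foreach { case (key, row) =>
      if (key != currentKey) {
        currentKey = key // same file-rollover point as above
      }
      write(currentKey, row)
    }
  }
}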
val sortedByPartitionCols =
if (description.partitionColumns.length > outputOrderingExprs.length) {
false
} else {
description.partitionColumns.zip(outputOrderingExprs).forall {
case (partitionCol, outputOrderExpr) => partitionCol.semanticEquals(outputOrderExpr)
}
}

if (sortedByPartitionCols && bucketIdExpression.isEmpty) {
// If the input data is sorted by partition columns and no bucketing is specified,
// we don't need to sort the data by partition columns anymore.

val getPartitioningKey = UnsafeProjection.create(
description.partitionColumns, description.allColumns)

// If anything below fails, we should abort the task.
var recordsInFile: Long = 0L
var fileCounter = 0
var currentKey: UnsafeRow = null
val updatedPartitions = mutable.Set[String]()
while (iter.hasNext) {
val currentRow = iter.next()
val nextKey = getPartitioningKey(currentRow).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
// See a new key - write to a new partition (new file).
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")

recordsInFile = 0
fileCounter = 0

releaseResources()
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
val partitionPath = getPartitionStringFunc(currentKey).getString(0)
if (partitionPath.nonEmpty) {
updatedPartitions.add(partitionPath)
}
} else if (description.maxRecordsPerFile > 0 &&
recordsInFile >= description.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
// Create a new file by increasing the file counter.
recordsInFile = 0
fileCounter += 1
assert(fileCounter < MAX_FILE_COUNTER,
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

releaseResources()
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
}

currentWriter.write(getOutputRow(currentRow))
recordsInFile += 1
}
releaseResources()
updatedPartitions.toSet
} else {
executeWithSort(iter)
}
}

private def executeWithSort(iter: Iterator[InternalRow]): Set[String] = {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val sortingExpressions: Seq[Expression] =
description.partitionColumns ++ bucketIdExpression ++ sortColumns
@@ -381,14 +454,6 @@
case _ => StructField("bucketId", IntegerType, nullable = false)
})

// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(
description.dataColumns, description.allColumns)

// Returns the partition path given a partition key.
val getPartitionStringFunc = UnsafeProjection.create(
Seq(Concat(partitionStringExpression)), description.partitionColumns)

// Sorts the data before write, so that we only need one writer at the same time.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
@@ -404,6 +469,8 @@
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}

val sortedIterator = sorter.sortedIterator()

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
@@ -412,8 +479,6 @@
})
}

val sortedIterator = sorter.sortedIterator()

// If anything below fails, we should abort the task.
var recordsInFile: Long = 0L
var fileCounter = 0
@@ -487,6 +487,36 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
}
}

test("SPARK-19352: Keep sort order of rows after external sorter when writing") {
Contributor: Again, this is not guaranteed, so we should not test it. This is an optimization; advanced users can leverage it to preserve the sort order, but it may change in the future.

Member Author: Got it.

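Not part of the diff: following up on the review comment above, an illustrative reader-side sketch. Since the preserved file order is an optimization rather than a contract, a consumer that needs a deterministic order should still sort explicitly on read; the object, method, session, and path names here are assumptions, not from the patch.

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.col

object OrderedReadExample {
  // The on-disk row order written by the task above may change across Spark versions,
  // so re-impose the order explicitly when it matters.
  def readPartitionOrdered(spark: SparkSession, tempDir: String): DataFrame =
    spark.read.parquet(tempDir)
      .filter(col("id") === 65)
      .sort(col("value").desc)
}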
spark.stop()
// Explicitly set memory configuration to force `UnsafeKVExternalSorter` to spill to files
// when inserting data.
val newSpark = SparkSession.builder()
.master("local")
.appName("test")
.config("spark.buffer.pageSize", "16b")
.config("spark.testing.memory", "1400")
.config("spark.memory.fraction", "0.1")
.config("spark.shuffle.sort.initialBufferSize", "2")
.config("spark.memory.offHeap.enabled", "false")
.getOrCreate()
withTempPath { path =>
val tempDir = path.getCanonicalPath
val df = newSpark.range(100)
.select($"id", explode(array(col("id") + 1, col("id") + 2, col("id") + 3)).as("value"))
.repartition($"id")
.sortWithinPartitions($"id", $"value".desc).toDF()

df.write
.partitionBy("id")
.parquet(tempDir)

val dfReadIn = newSpark.read.parquet(tempDir).select("id", "value")
checkAnswer(df.filter("id = 65"), dfReadIn.filter("id = 65"))
}
newSpark.stop()
}

// Helpers for checking the arguments passed to the FileFormat.

protected val checkPartitionSchema =