
Commit e523245

address comments
1 parent 917d3fc commit e523245

2 files changed: 45 additions & 132 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala

Lines changed: 45 additions & 129 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.UnsafeKVExternalSorter
 import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.types.{IntegerType, StructType, StringType}
+import org.apache.spark.sql.types.{StructField, IntegerType, StructType, StringType}
 import org.apache.spark.util.SerializableConfiguration
 
 
@@ -364,52 +364,6 @@ private[sql] class DynamicPartitionWriterContainer(
     }
   }
 
-  private def sortBasedWrite(
-      sorter: UnsafeKVExternalSorter,
-      iterator: Iterator[InternalRow],
-      getSortingKey: UnsafeProjection,
-      getOutputRow: UnsafeProjection,
-      getPartitionString: UnsafeProjection,
-      outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
-    while (iterator.hasNext) {
-      val currentRow = iterator.next()
-      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-    }
-
-    logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-    val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
-      (key1, key2) => key1 != key2
-    } else {
-      (key1, key2) => key1 == null || !sameBucket(key1, key2)
-    }
-
-    val sortedIterator = sorter.sortedIterator()
-    var currentKey: UnsafeRow = null
-    var currentWriter: OutputWriter = null
-    try {
-      while (sortedIterator.next()) {
-        if (needNewWriter(currentKey, sortedIterator.getKey)) {
-          if (currentWriter != null) {
-            currentWriter.close()
-          }
-          currentKey = sortedIterator.getKey.copy()
-          logDebug(s"Writing partition: $currentKey")
-
-          // Either use an existing file from before, or open a new one.
-          currentWriter = outputWriters.remove(currentKey)
-          if (currentWriter == null) {
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
-          }
-        }
-
-        currentWriter.writeInternal(sortedIterator.getValue)
-      }
-    } finally {
-      if (currentWriter != null) { currentWriter.close() }
-    }
-  }
-
   /**
    * Open and returns a new OutputWriter given a partition key and optional bucket id.
    * If bucket id is specified, we will append it to the end of the file name, but before the
@@ -435,22 +389,18 @@
   }
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
     executorSideSetup(taskContext)
 
-    var outputWritersCleared = false
-
     // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val getSortingKey =
-      UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)
-
-    val sortingKeySchema = if (bucketSpec.isEmpty) {
-      StructType.fromAttributes(partitionColumns)
-    } else { // If it's bucketed, we should also consider bucket id as part of the key.
-      val fields = StructType.fromAttributes(partitionColumns)
-        .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
-      StructType(fields)
-    }
+    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+
+    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+    val sortingKeySchema = StructType(sortingExpressions.map {
+      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+      // The sorting expressions are all `Attribute` except bucket id.
+      case _ => StructField("bucketId", IntegerType, nullable = false)
+    })
 
     // Returns the data columns to be written given an input row
     val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
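
The hunk above replaces the bucketed/non-bucketed special casing with one rule for deriving the sorter's key schema: every sorting expression is an attribute and keeps its name, type, and nullability, except the optional bucket-id expression, which becomes a non-nullable integer field named "bucketId". A minimal, self-contained sketch of that mapping, using stand-in expression types (Attr, BucketId) instead of Catalyst's Expression/Attribute and hypothetical column names:

// Sketch only: SortExpr, Attr and BucketId are illustrative stand-ins for Catalyst's
// Expression, Attribute and the bucket-id expression; only the types imported from
// org.apache.spark.sql.types are real Spark APIs.
import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}

object SortingKeySchemaSketch {
  sealed trait SortExpr
  case class Attr(name: String, dataType: DataType, nullable: Boolean = true) extends SortExpr
  case object BucketId extends SortExpr

  // Mirrors the patched logic: attributes map straight to fields, anything else
  // is taken to be the bucket id and becomes a non-nullable integer field.
  def sortingKeySchema(sortingExpressions: Seq[SortExpr]): StructType =
    StructType(sortingExpressions.map {
      case Attr(name, dataType, nullable) => StructField(name, dataType, nullable)
      case BucketId => StructField("bucketId", IntegerType, nullable = false)
    })

  def main(args: Array[String]): Unit = {
    val partitionColumns = Seq(Attr("year", IntegerType))   // hypothetical columns
    val sortColumns = Seq(Attr("name", StringType))
    // partition columns ++ bucket id ++ sort columns, as in the new writeRows.
    println(sortingKeySchema(partitionColumns ++ Seq(BucketId) ++ sortColumns))
  }
}

As in the patched code, the catch-all case assumes the only non-attribute sorting expression is the bucket id.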
@@ -461,54 +411,46 @@
 
     // If anything below fails, we should abort the task.
     try {
-      // If there is no sorting columns, we set sorter to null and try the hash-based writing first,
-      // and fill the sorter if there are too many writers and we need to fall back on sorting.
-      // If there are sorting columns, then we have to sort the data anyway, and no need to try the
-      // hash-based writing first.
-      var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
-        new UnsafeKVExternalSorter(
-          sortingKeySchema,
-          StructType.fromAttributes(dataColumns),
-          SparkEnv.get.blockManager,
-          TaskContext.get().taskMemoryManager().pageSizeBytes)
+      // Sorts the data before write, so that we only need one writer at the same time.
+      // TODO: inject a local sort operator in planning.
+      val sorter = new UnsafeKVExternalSorter(
+        sortingKeySchema,
+        StructType.fromAttributes(dataColumns),
+        SparkEnv.get.blockManager,
+        TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+      while (iterator.hasNext) {
+        val currentRow = iterator.next()
+        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+      }
+
+      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+      val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
+        (key1, key2) => key1 != key2
       } else {
-        null
+        (key1, key2) => key1 == null || !sameBucket(key1, key2)
       }
-      while (iterator.hasNext && sorter == null) {
-        val inputRow = iterator.next()
-        // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
-        val currentKey = getSortingKey(inputRow)
-        var currentWriter = outputWriters.get(currentKey)
-
-        if (currentWriter == null) {
-          if (outputWriters.size < maxOpenFiles) {
+
+      val sortedIterator = sorter.sortedIterator()
+      var currentKey: UnsafeRow = null
+      var currentWriter: OutputWriter = null
+      try {
+        while (sortedIterator.next()) {
+          if (needNewWriter(currentKey, sortedIterator.getKey)) {
+            if (currentWriter != null) {
+              currentWriter.close()
+            }
+            currentKey = sortedIterator.getKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
             currentWriter = newOutputWriter(currentKey, getPartitionString)
-            outputWriters.put(currentKey.copy(), currentWriter)
-            currentWriter.writeInternal(getOutputRow(inputRow))
-          } else {
-            logInfo(s"Maximum partitions reached, falling back on sorting.")
-            sorter = new UnsafeKVExternalSorter(
-              sortingKeySchema,
-              StructType.fromAttributes(dataColumns),
-              SparkEnv.get.blockManager,
-              TaskContext.get().taskMemoryManager().pageSizeBytes)
-            sorter.insertKV(currentKey, getOutputRow(inputRow))
           }
-        } else {
-          currentWriter.writeInternal(getOutputRow(inputRow))
-        }
-      }
 
-      // If the sorter is not null that means that we reached the maxFiles above and need to finish
-      // using external sort, or there are sorting columns and we need to sort the whole data set.
-      if (sorter != null) {
-        sortBasedWrite(
-          sorter,
-          iterator,
-          getSortingKey,
-          getOutputRow,
-          getPartitionString,
-          outputWriters)
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+      } finally {
+        if (currentWriter != null) { currentWriter.close() }
       }
 
       commitTask()
@@ -518,31 +460,5 @@ private[sql] class DynamicPartitionWriterContainer(
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
-
-    def clearOutputWriters(): Unit = {
-      if (!outputWritersCleared) {
-        outputWriters.asScala.values.foreach(_.close())
-        outputWriters.clear()
-        outputWritersCleared = true
-      }
-    }
-
-    def commitTask(): Unit = {
-      try {
-        clearOutputWriters()
-        super.commitTask()
-      } catch {
-        case cause: Throwable =>
-          throw new RuntimeException("Failed to commit task", cause)
-      }
-    }
-
-    def abortTask(): Unit = {
-      try {
-        clearOutputWriters()
-      } finally {
-        super.abortTask()
-      }
-    }
   }
 }
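
Taken together, the WriterContainer.scala changes drop the hash-based fast path (and with it the HashMap of open writers, the maxOpenFiles fallback, and the local commit/abort cleanup) in favor of always sorting: every row goes into the UnsafeKVExternalSorter keyed on partition columns, bucket id, and sort columns, and the sorted iterator is then written with at most one OutputWriter open at a time, switching writers when the key, or its partition/bucket prefix, changes. A rough, self-contained sketch of that control flow, with an in-memory sort and a toy writer standing in for UnsafeKVExternalSorter and OutputWriter (all names here are hypothetical):

// Sketch only: Key, Row and ToyWriter are illustrative stand-ins, and the in-memory
// sortBy replaces Spark's spill-capable UnsafeKVExternalSorter.
object SortBasedWriteSketch {
  type Key = Seq[Any]   // partition values ++ bucket id ++ sort-column values
  type Row = Seq[Any]   // the data columns that actually get written

  class ToyWriter(key: Seq[Any]) {
    def write(row: Row): Unit = println(s"[${key.mkString(",")}] ${row.mkString(",")}")
    def close(): Unit = println(s"[${key.mkString(",")}] closed")
  }

  def writeSorted(rows: Iterator[(Key, Row)], keyPrefixLength: Int): Unit = {
    // 1. Sort everything by key (Spark uses an external, spillable KV sorter here).
    val sorted = rows.toSeq.sortBy(_._1.mkString("\u0000"))

    // 2. Stream the sorted data, keeping at most one writer open at a time.
    //    A new writer is opened whenever the partition/bucket prefix of the key changes.
    var currentPrefix: Option[Seq[Any]] = None
    var currentWriter: ToyWriter = null
    try {
      for ((key, row) <- sorted) {
        val prefix = key.take(keyPrefixLength)
        if (!currentPrefix.contains(prefix)) {
          if (currentWriter != null) currentWriter.close()
          currentPrefix = Some(prefix)
          currentWriter = new ToyWriter(prefix)
        }
        currentWriter.write(row)
      }
    } finally {
      if (currentWriter != null) currentWriter.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val data = Iterator(
      (Seq(2015, 1), Seq("alice")),
      (Seq(2014, 0), Seq("bob")),
      (Seq(2015, 1), Seq("carol")))
    // Key = (year, bucketId) with no extra sort columns, so the whole key is the prefix.
    writeSorted(data, keyPrefixLength = 2)
  }
}

In the patched code the writer-switch test is either whole-key inequality (no sort columns) or the sameBucket check; the keyPrefixLength parameter above is only a simplified stand-in for that distinction.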

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 0 additions & 3 deletions
@@ -162,7 +162,6 @@ trait HadoopFsRelationProvider {
       partitionColumns: Option[StructType],
       parameters: Map[String, String]): HadoopFsRelation
 
-  // TODO: expose bucket API to users.
   private[sql] def createRelation(
       sqlContext: SQLContext,
       paths: Array[String],
@@ -370,7 +369,6 @@ abstract class OutputWriterFactory extends Serializable {
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter
 
-  // TODO: expose bucket API to users.
   private[sql] def newInstance(
       path: String,
       bucketId: Option[Int],
@@ -460,7 +458,6 @@ abstract class HadoopFsRelation private[sql](
 
   private var _partitionSpec: PartitionSpec = _
 
-  // TODO: expose bucket API to users.
   private[sql] def bucketSpec: Option[BucketSpec] = None
 
   private class FileStatusCache {
