@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.UnsafeKVExternalSorter
 import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.types.{IntegerType, StructType, StringType}
+import org.apache.spark.sql.types.{StructField, IntegerType, StructType, StringType}
 import org.apache.spark.util.SerializableConfiguration
 
 
@@ -364,52 +364,6 @@ private[sql] class DynamicPartitionWriterContainer(
     }
   }
 
-  private def sortBasedWrite(
-      sorter: UnsafeKVExternalSorter,
-      iterator: Iterator[InternalRow],
-      getSortingKey: UnsafeProjection,
-      getOutputRow: UnsafeProjection,
-      getPartitionString: UnsafeProjection,
-      outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = {
-    while (iterator.hasNext) {
-      val currentRow = iterator.next()
-      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-    }
-
-    logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
-    val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
-      (key1, key2) => key1 != key2
-    } else {
-      (key1, key2) => key1 == null || !sameBucket(key1, key2)
-    }
-
-    val sortedIterator = sorter.sortedIterator()
-    var currentKey: UnsafeRow = null
-    var currentWriter: OutputWriter = null
-    try {
-      while (sortedIterator.next()) {
-        if (needNewWriter(currentKey, sortedIterator.getKey)) {
-          if (currentWriter != null) {
-            currentWriter.close()
-          }
-          currentKey = sortedIterator.getKey.copy()
-          logDebug(s"Writing partition: $currentKey")
-
-          // Either use an existing file from before, or open a new one.
-          currentWriter = outputWriters.remove(currentKey)
-          if (currentWriter == null) {
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
-          }
-        }
-
-        currentWriter.writeInternal(sortedIterator.getValue)
-      }
-    } finally {
-      if (currentWriter != null) { currentWriter.close() }
-    }
-  }
-
   /**
    * Open and returns a new OutputWriter given a partition key and optional bucket id.
    * If bucket id is specified, we will append it to the end of the file name, but before the
@@ -435,22 +389,18 @@ private[sql] class DynamicPartitionWriterContainer(
   }
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
-    val outputWriters = new java.util.HashMap[InternalRow, OutputWriter]
     executorSideSetup(taskContext)
 
-    var outputWritersCleared = false
-
     // We should first sort by partition columns, then bucket id, and finally sorting columns.
-    val getSortingKey =
-      UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema)
-
-    val sortingKeySchema = if (bucketSpec.isEmpty) {
-      StructType.fromAttributes(partitionColumns)
-    } else { // If it's bucketed, we should also consider bucket id as part of the key.
-      val fields = StructType.fromAttributes(partitionColumns)
-        .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns)
-      StructType(fields)
-    }
+    val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
+
+    val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+    val sortingKeySchema = StructType(sortingExpressions.map {
+      case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+      // The sorting expressions are all `Attribute` except bucket id.
+      case _ => StructField("bucketId", IntegerType, nullable = false)
+    })
 
     // Returns the data columns to be written given an input row
     val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
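
(Not part of the patch.) The added hunk above derives the sorting-key schema directly from sortingExpressions: every Attribute keeps its own name, data type, and nullability, while the one non-attribute expression, the bucket id computation, is represented as a non-nullable IntegerType field named "bucketId". Below is a minimal standalone sketch of that mapping, using hypothetical field names and plain StructFields in place of Catalyst attributes:

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SortingKeySchemaSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical stand-ins for partitionColumns and sortColumns; in the real code
    // these are Catalyst Attributes and the bucket id is a computed expression.
    val partitionFields = Seq(
      StructField("year", IntegerType, nullable = true),
      StructField("country", StringType, nullable = true))
    val sortFields = Seq(StructField("name", StringType, nullable = true))

    // partitionColumns ++ bucketIdExpression ++ sortColumns, expressed as a schema:
    // attributes keep their own field, the bucket id becomes a non-nullable int.
    val sortingKeySchema = StructType(
      partitionFields ++ Seq(StructField("bucketId", IntegerType, nullable = false)) ++ sortFields)

    println(sortingKeySchema.treeString)
  }
}

This only illustrates the schema shape; in the patched code, UnsafeProjection.create(sortingExpressions, inputSchema) produces the key rows that match this schema.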
@@ -461,54 +411,46 @@ private[sql] class DynamicPartitionWriterContainer(
 
     // If anything below fails, we should abort the task.
     try {
-      // If there is no sorting columns, we set sorter to null and try the hash-based writing first,
-      // and fill the sorter if there are too many writers and we need to fall back on sorting.
-      // If there are sorting columns, then we have to sort the data anyway, and no need to try the
-      // hash-based writing first.
-      var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) {
-        new UnsafeKVExternalSorter(
-          sortingKeySchema,
-          StructType.fromAttributes(dataColumns),
-          SparkEnv.get.blockManager,
-          TaskContext.get().taskMemoryManager().pageSizeBytes)
+      // Sorts the data before write, so that we only need one writer at the same time.
+      // TODO: inject a local sort operator in planning.
+      val sorter = new UnsafeKVExternalSorter(
+        sortingKeySchema,
+        StructType.fromAttributes(dataColumns),
+        SparkEnv.get.blockManager,
+        TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+      while (iterator.hasNext) {
+        val currentRow = iterator.next()
+        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+      }
+
+      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+      val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) {
+        (key1, key2) => key1 != key2
       } else {
-        null
+        (key1, key2) => key1 == null || !sameBucket(key1, key2)
       }
-      while (iterator.hasNext && sorter == null) {
-        val inputRow = iterator.next()
-        // When we reach here, the `sortColumns` must be empty, so the sorting key is hashing key.
-        val currentKey = getSortingKey(inputRow)
-        var currentWriter = outputWriters.get(currentKey)
-
-        if (currentWriter == null) {
-          if (outputWriters.size < maxOpenFiles) {
+
+      val sortedIterator = sorter.sortedIterator()
+      var currentKey: UnsafeRow = null
+      var currentWriter: OutputWriter = null
+      try {
+        while (sortedIterator.next()) {
+          if (needNewWriter(currentKey, sortedIterator.getKey)) {
+            if (currentWriter != null) {
+              currentWriter.close()
+            }
+            currentKey = sortedIterator.getKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
             currentWriter = newOutputWriter(currentKey, getPartitionString)
-            outputWriters.put(currentKey.copy(), currentWriter)
-            currentWriter.writeInternal(getOutputRow(inputRow))
-          } else {
-            logInfo(s"Maximum partitions reached, falling back on sorting.")
-            sorter = new UnsafeKVExternalSorter(
-              sortingKeySchema,
-              StructType.fromAttributes(dataColumns),
-              SparkEnv.get.blockManager,
-              TaskContext.get().taskMemoryManager().pageSizeBytes)
-            sorter.insertKV(currentKey, getOutputRow(inputRow))
           }
-        } else {
-          currentWriter.writeInternal(getOutputRow(inputRow))
-        }
-      }
 
-      // If the sorter is not null that means that we reached the maxFiles above and need to finish
-      // using external sort, or there are sorting columns and we need to sort the whole data set.
-      if (sorter != null) {
-        sortBasedWrite(
-          sorter,
-          iterator,
-          getSortingKey,
-          getOutputRow,
-          getPartitionString,
-          outputWriters)
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+      } finally {
+        if (currentWriter != null) { currentWriter.close() }
       }
 
       commitTask()
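
(Not part of the patch.) With the hash-based fast path removed, writeRows now always inserts every row into the external sorter and then streams the sorted iterator, so at most one OutputWriter is open at a time and a new writer is opened only when the partition/bucket key changes. Below is a self-contained sketch of that write pattern on plain Scala collections, with a hypothetical Writer class standing in for OutputWriter:

object SortedWriteSketch {
  // Hypothetical stand-in for OutputWriter / newOutputWriter.
  final class Writer(key: String) {
    def write(value: String): Unit = println(s"[$key] $value")
    def close(): Unit = println(s"close [$key]")
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq("b" -> "2", "a" -> "1", "c" -> "4", "a" -> "3")

    // Sort by key first, so rows for the same partition/bucket are contiguous
    // and only one writer ever needs to be open.
    val sorted = rows.sortBy(_._1)

    var currentKey: String = null
    var currentWriter: Writer = null
    try {
      for ((key, value) <- sorted) {
        // Key changed (or this is the first row): close the old writer, open a new one.
        if (currentKey == null || key != currentKey) {
          if (currentWriter != null) currentWriter.close()
          currentKey = key
          currentWriter = new Writer(key)
        }
        currentWriter.write(value)
      }
    } finally {
      if (currentWriter != null) currentWriter.close()
    }
  }
}

Sorting up front trades extra CPU and possible spill I/O for a bounded number of open files, which is why the old maxOpenFiles bookkeeping and the per-key HashMap of writers can be deleted in the hunks above and below.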
@@ -518,31 +460,5 @@ private[sql] class DynamicPartitionWriterContainer(
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
-
-    def clearOutputWriters(): Unit = {
-      if (!outputWritersCleared) {
-        outputWriters.asScala.values.foreach(_.close())
-        outputWriters.clear()
-        outputWritersCleared = true
-      }
-    }
-
-    def commitTask(): Unit = {
-      try {
-        clearOutputWriters()
-        super.commitTask()
-      } catch {
-        case cause: Throwable =>
-          throw new RuntimeException("Failed to commit task", cause)
-      }
-    }
-
-    def abortTask(): Unit = {
-      try {
-        clearOutputWriters()
-      } finally {
-        super.abortTask()
-      }
-    }
   }
 }