 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -37,7 +39,7 @@ case class Sort(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
 
@@ -50,34 +52,36 @@ case class Sort(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
-  protected override def doExecute(): RDD[InternalRow] = {
-    val schema = child.schema
-    val childOutput = child.output
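+  // Builds the UnsafeExternalRowSorter for this operator; called both from doExecute() and
+  // from the generated code via a reference to this plan.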
+  def createSorter(): UnsafeExternalRowSorter = {
+    val ordering = newOrdering(sortOrder, output)
+
+    // The comparator for comparing prefix
+    val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+    val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+    // The generator for prefix
+    val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+    val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+      override def computePrefix(row: InternalRow): Long = {
+        prefixProjection.apply(row).getLong(0)
+      }
+    }
 
+    val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+    val sorter = new UnsafeExternalRowSorter(
+      schema, ordering, prefixComparator, prefixComputer, pageSize)
+    if (testSpillFrequency > 0) {
+      sorter.setTestSpillFrequency(testSpillFrequency)
+    }
+    sorter
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
     val dataSize = longMetric("dataSize")
     val spillSize = longMetric("spillSize")
 
     child.execute().mapPartitionsInternal { iter =>
-      val ordering = newOrdering(sortOrder, childOutput)
-
-      // The comparator for comparing prefix
-      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
-      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
-      // The generator for prefix
-      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
-      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = {
-          prefixProjection.apply(row).getLong(0)
-        }
-      }
-
-      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-      val sorter = new UnsafeExternalRowSorter(
-        schema, ordering, prefixComparator, prefixComputer, pageSize)
-      if (testSpillFrequency > 0) {
-        sorter.setTestSpillFrequency(testSpillFrequency)
-      }
+      val sorter = createSorter()
 
       val metrics = TaskContext.get().taskMetrics()
       // Remember spill data size of this task before execute this operator so that we can
@@ -93,4 +97,74 @@ case class Sort(
       sortedIterator
     }
   }
+
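+  // Whole-stage codegen support: the input RDDs come from the child operator.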
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  // Name of sorter variable used in codegen.
+  private var sorterVariable: String = _
+
+  override protected def doProduce(ctx: CodegenContext): String = {
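+    // Flag in the generated class; makes sure the input is sorted only once, on the first call.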
+    val needToSort = ctx.freshName("needToSort")
+    ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+    // Initialize the class member variables. This includes the instance of the Sorter and
+    // the iterator to return sorted rows.
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    sorterVariable = ctx.freshName("sorter")
+    ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+      s"$sorterVariable = $thisPlan.createSorter();")
+    val metrics = ctx.freshName("metrics")
+    ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
+      s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
+    val sortedIterator = ctx.freshName("sortedIter")
+    ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")
+
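+    // Generate a separate method that runs the child's produce() code, so every input row is
+    // inserted into the sorter before sort() is called.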
+    val addToSorter = ctx.freshName("addToSorter")
+    ctx.addNewFunction(addToSorter,
+      s"""
+        | private void $addToSorter() throws java.io.IOException {
+        |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+        | }
+      """.stripMargin.trim)
+
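+    // Produced code: on the first call, feed all input rows into the sorter and sort them; then
+    // loop over the sorted iterator, handing each row to the parent via consume().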
+    val outputRow = ctx.freshName("outputRow")
+    val dataSize = metricTerm(ctx, "dataSize")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val spillSizeBefore = ctx.freshName("spillSizeBefore")
+    s"""
+       | if ($needToSort) {
+       |   $addToSorter();
+       |   Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+       |   $sortedIterator = $sorterVariable.sort();
+       |   $dataSize.add($sorterVariable.getPeakMemoryUsage());
+       |   $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
+       |   $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
+       |   $needToSort = false;
+       | }
+       |
+       | while ($sortedIterator.hasNext()) {
+       |   UnsafeRow $outputRow = (UnsafeRow) $sortedIterator.next();
+       |   ${consume(ctx, null, outputRow)}
+       |   if (shouldStop()) return;
+       | }
+     """.stripMargin.trim
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+      BoundReference(i, attr.dataType, attr.nullable)
+    }
+
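+    // Bind the incoming column variables and generate an UnsafeRow projection over them.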
+    ctx.currentVars = input
+    val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+
+    s"""
+       | // Convert the input attributes to an UnsafeRow and add it to the sorter
+       | ${code.code}
+       | $sorterVariable.insertRow(${code.value});
+     """.stripMargin.trim
+  }
 }