
Commit 4bd697d

sameeragarwal authored and yhuai committed
[SPARK-13123][SQL] Implement whole stage codegen for sort
## What changes were proposed in this pull request?

This PR adds support for whole stage codegen for sort. It builds heavily on nongli's PR #11008 (which implements the feature itself) and adds the following changes on top:

- [x] Generated code updates peak execution memory metrics
- [x] Unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite`

## How was this patch tested?

New unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite`. Further, all existing sort tests should pass.

Author: Sameer Agarwal <sameer@databricks.com>
Author: Nong Li <nong@databricks.com>

Closes #11359 from sameeragarwal/sort-codegen.
1 parent 644dbb6 · commit 4bd697d

File tree: 5 files changed, +122 −35 lines
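
For a quick end-to-end view of what this commit enables, the sketch below (assuming a Spark 2.0-era `sqlContext`, mirroring the new test in `WholeStageCodegenSuite`) sorts a small range; after this patch the executed plan should show `Sort` inside a `WholeStageCodegen` node, with results unchanged.

```scala
import org.apache.spark.sql.functions.col

// Minimal sketch, assuming a Spark 2.0-era `sqlContext` (as in the suites below).
val df = sqlContext.range(3, 0, -1).sort(col("id"))

// After this patch, the executed plan should contain Sort inside WholeStageCodegen.
df.explain()

// Codegen must not change results: Array(Row(1), Row(2), Row(3)).
df.collect()
```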

sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java

Lines changed: 3 additions & 6 deletions
```diff
@@ -36,7 +36,7 @@
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
 
-final class UnsafeExternalRowSorter {
+public final class UnsafeExternalRowSorter {
 
   /**
    * If positive, forces records to be spilled to disk at the given frequency (measured in numbers
@@ -84,8 +84,7 @@ void setTestSpillFrequency(int frequency) {
     testSpillFrequency = frequency;
   }
 
-  @VisibleForTesting
-  void insertRow(UnsafeRow row) throws IOException {
+  public void insertRow(UnsafeRow row) throws IOException {
     final long prefix = prefixComputer.computePrefix(row);
     sorter.insertRecord(
       row.getBaseObject(),
@@ -110,8 +109,7 @@ private void cleanupResources() {
     sorter.cleanupResources();
  }
 
-  @VisibleForTesting
-  Iterator<UnsafeRow> sort() throws IOException {
+  public Iterator<UnsafeRow> sort() throws IOException {
    try {
      final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
      if (!sortedIterator.hasNext()) {
@@ -160,7 +158,6 @@ public UnsafeRow next() {
     }
   }
 
-
   public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
     while (inputIterator.hasNext()) {
       insertRow(inputIterator.next());
```

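Context for the visibility changes above: `insertRow` and `sort` were package-private test hooks, but the code generated for `Sort` lives outside this class, so they (and the class itself) become `public`. A sketch of the per-partition calling pattern the generated code follows (the helper and parameter names here are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.UnsafeExternalRowSorter

// Illustrative sketch of the per-partition calling pattern emitted by
// Sort's doProduce/doConsume (see Sort.scala below); names are hypothetical.
def sortPartition(
    createSorter: () => UnsafeExternalRowSorter,  // Sort.createSorter, called from generated init code
    inputRows: Iterator[UnsafeRow]): Iterator[UnsafeRow] = {
  val sorter = createSorter()
  while (inputRows.hasNext) {
    sorter.insertRow(inputRows.next())  // doConsume emits one insertRow call per row
  }
  sorter.sort()                         // doProduce sorts once, then iterates the output
}
```
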
sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala

Lines changed: 99 additions & 25 deletions
```diff
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -37,7 +39,7 @@ case class Sort(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
 
@@ -50,34 +52,36 @@ case class Sort(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
-  protected override def doExecute(): RDD[InternalRow] = {
-    val schema = child.schema
-    val childOutput = child.output
+  def createSorter(): UnsafeExternalRowSorter = {
+    val ordering = newOrdering(sortOrder, output)
+
+    // The comparator for comparing prefix
+    val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+    val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+    // The generator for prefix
+    val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+    val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+      override def computePrefix(row: InternalRow): Long = {
+        prefixProjection.apply(row).getLong(0)
+      }
+    }
 
+    val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+    val sorter = new UnsafeExternalRowSorter(
+      schema, ordering, prefixComparator, prefixComputer, pageSize)
+    if (testSpillFrequency > 0) {
+      sorter.setTestSpillFrequency(testSpillFrequency)
+    }
+    sorter
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
     val dataSize = longMetric("dataSize")
     val spillSize = longMetric("spillSize")
 
     child.execute().mapPartitionsInternal { iter =>
-      val ordering = newOrdering(sortOrder, childOutput)
-
-      // The comparator for comparing prefix
-      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
-      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
-      // The generator for prefix
-      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
-      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = {
-          prefixProjection.apply(row).getLong(0)
-        }
-      }
-
-      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-      val sorter = new UnsafeExternalRowSorter(
-        schema, ordering, prefixComparator, prefixComputer, pageSize)
-      if (testSpillFrequency > 0) {
-        sorter.setTestSpillFrequency(testSpillFrequency)
-      }
+      val sorter = createSorter()
 
       val metrics = TaskContext.get().taskMetrics()
       // Remember spill data size of this task before execute this operator so that we can
@@ -93,4 +97,74 @@ case class Sort(
       sortedIterator
     }
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  // Name of sorter variable used in codegen.
+  private var sorterVariable: String = _
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    val needToSort = ctx.freshName("needToSort")
+    ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+
+    // Initialize the class member variables. This includes the instance of the Sorter and
+    // the iterator to return sorted rows.
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    sorterVariable = ctx.freshName("sorter")
+    ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+      s"$sorterVariable = $thisPlan.createSorter();")
+    val metrics = ctx.freshName("metrics")
+    ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
+      s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
+    val sortedIterator = ctx.freshName("sortedIter")
+    ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")
+
+    val addToSorter = ctx.freshName("addToSorter")
+    ctx.addNewFunction(addToSorter,
+      s"""
+        | private void $addToSorter() throws java.io.IOException {
+        |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+        | }
+      """.stripMargin.trim)
+
+    val outputRow = ctx.freshName("outputRow")
+    val dataSize = metricTerm(ctx, "dataSize")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val spillSizeBefore = ctx.freshName("spillSizeBefore")
+    s"""
+      | if ($needToSort) {
+      |   $addToSorter();
+      |   Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+      |   $sortedIterator = $sorterVariable.sort();
+      |   $dataSize.add($sorterVariable.getPeakMemoryUsage());
+      |   $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
+      |   $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
+      |   $needToSort = false;
+      | }
+      |
+      | while ($sortedIterator.hasNext()) {
+      |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
+      |   ${consume(ctx, null, outputRow)}
+      |   if (shouldStop()) return;
+      | }
    """.stripMargin.trim
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+      BoundReference(i, attr.dataType, attr.nullable)
+    }
+
+    ctx.currentVars = input
+    val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+
+    s"""
+      | // Convert the input attributes to an UnsafeRow and add it to the sorter
+      | ${code.code}
+      | $sorterVariable.insertRow(${code.value});
+    """.stripMargin.trim
+  }
 }
```

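In short, `doProduce` keeps the sorter and the sorted-row iterator in generated class state: on the first call it runs the child's entire produce loop inside the generated `addToSorter()` function, sorts once, records the metrics, and only then streams sorted rows to the parent. Stripped of `freshName` suffixes and interpolation noise, the template above expands to Java of roughly this shape (an illustration, not the literal generated output):

```scala
// Roughly the Java emitted by the doProduce template above, with freshName
// suffixes and metric plumbing elided (illustrative only).
val generatedSortShape: String =
  """
    | if (needToSort) {
    |   addToSorter();               // child's produce loop; calls sorter.insertRow per row
    |   sortedIter = sorter.sort();  // sort once per partition, then update metrics
    |   needToSort = false;
    | }
    | while (sortedIter.hasNext()) {
    |   UnsafeRow outputRow = (UnsafeRow) sortedIter.next();
    |   // parent operator's consume(...) code is inlined here
    |   if (shouldStop()) return;
    | }
  """.stripMargin
```
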
sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala

Lines changed: 4 additions & 4 deletions
```diff
@@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
         ${code.trim}
       }
     }
-    """
+    """.trim
 
     // try to compile, helpful for debug
     val cleanedSource = CodeFormatter.stripExtraNewLines(source)
@@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       // There is an UnsafeRow already
       s"""
         |append($row.copy());
-      """.stripMargin
+      """.stripMargin.trim
     } else {
       assert(input != null)
       if (input.nonEmpty) {
@@ -351,12 +351,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       s"""
         |${code.code.trim}
         |append(${code.value}.copy());
-      """.stripMargin
+      """.stripMargin.trim
     } else {
       // There is no columns
       s"""
        |append(unsafeRow);
-      """.stripMargin
+      """.stripMargin.trim
     }
   }
 }
```

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 9 additions & 0 deletions
```diff
@@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
+
+  test("Sort should be included in WholeStageCodegen") {
+    val df = sqlContext.range(3, 0, -1).sort(col("id"))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
+  }
 }
```

sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala

Lines changed: 7 additions & 0 deletions
```diff
@@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     )
   }
 
+  test("Sort metrics") {
+    // Assume the execution plan is
+    // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
+    val df = sqlContext.range(10).sort('id)
+    testSparkPlanMetrics(df, 2, Map.empty)
+  }
+
   test("SortMergeJoin metrics") {
     // Because SortMergeJoin may skip different rows if the number of partitions is different, this
     // test should use the deterministic number of partitions.
```

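A note on the new `Sort metrics` test: it passes an empty expectation map, so no per-node metric values are asserted, only that the query runs under the expected plan shape. To eyeball the `dataSize` and `spillSize` values that the generated code now updates, something like the following sketch works from code compiled under `org.apache.spark.sql` (assumed here because `SparkPlan.metrics` is `private[sql]`):

```scala
import org.apache.spark.sql.execution.{Sort, WholeStageCodegen}

// Sketch: inspecting Sort's SQL metrics after an action; assumes access from
// the org.apache.spark.sql package, since SparkPlan.metrics is private[sql].
val df = sqlContext.range(10).sort("id")
df.collect()  // run the query so the metrics are populated

val sortNode = df.queryExecution.executedPlan.collect {
  case w: WholeStageCodegen if w.plan.isInstanceOf[Sort] => w.plan.asInstanceOf[Sort]
}.head
println(sortNode.metrics("dataSize").value)   // sorter peak memory, added in doProduce
println(sortNode.metrics("spillSize").value)  // bytes spilled while sorting
```
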