
Commit e3133f4

cloud-fan authored and kiszk committed
[SPARK-25497][SQL] Limit operation within whole stage codegen should not consume all the inputs
## What changes were proposed in this pull request? This PR is inspired by #22524, but proposes a safer fix. The current limit whole stage codegen has 2 problems: 1. It's only applied to `InputAdapter`, many leaf nodes can't stop earlier w.r.t. limit. 2. It needs to override a method, which will break if we have more than one limit in the whole-stage. The first problem is easy to fix, just figure out which nodes can stop earlier w.r.t. limit, and update them. This PR updates `RangeExec`, `ColumnarBatchScan`, `SortExec`, `HashAggregateExec`. The second problem is hard to fix. This PR proposes to propagate the limit counter variable name upstream, so that the upstream leaf/blocking nodes can check the limit counter and quit the loop earlier. For better performance, the implementation here follows `CodegenSupport.needStopCheck`, so that we only codegen the check only if there is limit in the query. For columnar node like range, we check the limit counter per-batch instead of per-row, to make the inner loop tight and fast. Why this is safer? 1. the leaf/blocking nodes don't have to check the limit counter and stop earlier. It's only for performance. (this is same as before) 2. The blocking operators can stop propagating the limit counter name, because the counter of limit after blocking operators will never increase, before blocking operators consume all the data from upstream operators. So the upstream operators don't care about limit after blocking operators. This is also for performance only, it's OK if we forget to do it for some new blocking operators. ## How was this patch tested? a new test Closes #22630 from cloud-fan/limit. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Kazuaki Ishizaki <ishizaki@jp.ibm.com>
1 parent 46fe408 commit e3133f4

File tree: 8 files changed, +215 / -125 lines


sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java

Lines changed: 0 additions & 10 deletions
@@ -73,16 +73,6 @@ public void append(InternalRow row) {
     currentRows.add(row);
   }
 
-  /**
-   * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]].
-   *
-   * If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
-   * This interface is mainly used to limit the number of input rows.
-   */
-  public boolean stopEarly() {
-    return false;
-  }
-
   /**
    * Returns whether `processNext()` should stop processing next row from `input` or not.
    *

sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala

Lines changed: 2 additions & 2 deletions
@@ -136,7 +136,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
        |if ($batch == null) {
        |  $nextBatchFuncName();
        |}
-       |while ($batch != null) {
+       |while ($limitNotReachedCond $batch != null) {
        |  int $numRows = $batch.numRows();
        |  int $localEnd = $numRows - $idx;
        |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
@@ -166,7 +166,7 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     }
     val inputRow = if (needsUnsafeRowConversion) null else row
     s"""
-       |while ($input.hasNext()) {
+       |while ($limitNotReachedCond $input.hasNext()) {
        |  InternalRow $row = (InternalRow) $input.next();
        |  $numOutputRows.add(1);
        |  ${consume(ctx, outputVars, inputRow).trim}
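The change above moves the limit test into the per-batch outer loop, keeping the row-level inner loop free of extra branches. A standalone sketch of that shape (the names `drainWithBatchLimit`, `batches`, and `limit` are illustrative, not the generated identifiers):

```scala
// Sketch: the limit is tested once per batch; the last batch may overshoot, and the
// downstream Limit operator truncates the extra rows.
def drainWithBatchLimit(batches: Iterator[Array[Long]], limit: Int): Seq[Long] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[Long]
  while (out.length < limit && batches.hasNext) { // mirrors `while ($limitNotReachedCond $batch != null)`
    val batch = batches.next()
    var i = 0
    while (i < batch.length) { // tight inner loop, no per-row limit check
      out += batch(i)
      i += 1
    }
  }
  out.take(limit).toSeq
}

// drainWithBatchLimit(Iterator(Array(1L, 2L), Array(3L, 4L)), 3) == Seq(1, 2, 3)
```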

sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala

Lines changed: 2 additions & 10 deletions
@@ -39,7 +39,7 @@ case class SortExec(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with BlockingOperatorWithCodegen {
 
   override def output: Seq[Attribute] = child.output
 
@@ -124,14 +124,6 @@ case class SortExec(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
-  // The result rows come from the sort buffer, so this operator doesn't need to copy its result
-  // even if its child does.
-  override def needCopyResult: Boolean = false
-
-  // Sort operator always consumes all the input rows before outputting any result, so we don't
-  // need a stop check before sorting.
-  override def needStopCheck: Boolean = false
-
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort =
       ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
@@ -172,7 +164,7 @@ case class SortExec(
        |   $needToSort = false;
        | }
        |
-       | while ($sortedIterator.hasNext()) {
+       | while ($limitNotReachedCond $sortedIterator.hasNext()) {
        |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
        |   ${consume(ctx, null, outputRow)}
        |   if (shouldStop()) return;
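Since sort is blocking, the check added above cannot skip any sorting work; it only lets the operator stop emitting already-sorted rows. A rough standalone model of the emitted loop (all names illustrative; in the real generated code the Limit operator increments its counter inside `consume`):

```scala
// Stand-ins: the sorter's output iterator and a limit of 3.
val sortedIterator = Iterator(1L, 2L, 3L, 4L, 5L)
val limit = 3
var emitted = 0
val out = scala.collection.mutable.ArrayBuffer.empty[Long]
while (emitted < limit && sortedIterator.hasNext) { // mirrors `$limitNotReachedCond $sortedIterator.hasNext()`
  out += sortedIterator.next()
  emitted += 1
}
// out == ArrayBuffer(1, 2, 3); rows 4 and 5 are never copied out of the sort buffer
```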

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 58 additions & 1 deletion
@@ -345,6 +345,61 @@ trait CodegenSupport extends SparkPlan {
    * don't require shouldStop() in the loop of producing rows.
    */
   def needStopCheck: Boolean = parent.needStopCheck
+
+  /**
+   * A sequence of checks which evaluate to true if the downstream Limit operators have not yet
+   * received enough records and reached the limit. If the current node is a data producing node,
+   * it can leverage this information to stop producing data and complete the data flow earlier.
+   * Common data producing nodes are leaf nodes like Range and Scan, and blocking nodes like Sort
+   * and Aggregate. These checks should be put into the loop condition of the data producing loop.
+   */
+  def limitNotReachedChecks: Seq[String] = parent.limitNotReachedChecks
+
+  /**
+   * A helper method to generate the data producing loop condition according to the
+   * limit-not-reached checks.
+   */
+  final def limitNotReachedCond: String = {
+    // InputAdapter is also a leaf node.
+    val isLeafNode = children.isEmpty || this.isInstanceOf[InputAdapter]
+    if (!isLeafNode && !this.isInstanceOf[BlockingOperatorWithCodegen]) {
+      val errMsg = "Only leaf nodes and blocking nodes need to call 'limitNotReachedCond' " +
+        "in its data producing loop."
+      if (Utils.isTesting) {
+        throw new IllegalStateException(errMsg)
+      } else {
+        logWarning(s"[BUG] $errMsg Please open a JIRA ticket to report it.")
+      }
+    }
+    if (parent.limitNotReachedChecks.isEmpty) {
+      ""
+    } else {
+      parent.limitNotReachedChecks.mkString("", " && ", " &&")
+    }
+  }
+}
+
+/**
+ * A special kind of operators which support whole stage codegen. Blocking means these operators
+ * will consume all the inputs first, before producing output. Typical blocking operators are
+ * sort and aggregate.
+ */
+trait BlockingOperatorWithCodegen extends CodegenSupport {
+
+  // Blocking operators usually have some kind of buffer to keep the data before producing it, so
+  // they don't need to copy their result even if their child does.
+  override def needCopyResult: Boolean = false
+
+  // Blocking operators always consume all the input first, so their upstream operators don't
+  // need a stop check.
+  override def needStopCheck: Boolean = false
+
+  // Blocking operators need to consume all the inputs before producing any output. This means, a
+  // Limit operator after this blocking operator will never reach its limit during the execution
+  // of this blocking operator's upstream operators. Here we override this method to return Nil,
+  // so that upstream operators will not generate useless conditions (which always evaluate to
+  // false) for the Limit operators after this blocking operator.
+  override def limitNotReachedChecks: Seq[String] = Nil
 }
 
 
@@ -381,7 +436,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
       forceInline = true)
     val row = ctx.freshName("row")
     s"""
-       | while ($input.hasNext() && !stopEarly()) {
+       | while ($limitNotReachedCond $input.hasNext()) {
        |   InternalRow $row = (InternalRow) $input.next();
        |   ${consume(ctx, null, row).trim}
        |   if (shouldStop()) return;
@@ -677,6 +732,8 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
 
   override def needStopCheck: Boolean = true
 
+  override def limitNotReachedChecks: Seq[String] = Nil
+
   override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
 }
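Note the trailing separator in `mkString("", " && ", " &&")`: it lets callers splice `limitNotReachedCond` directly in front of their own loop predicate, with no special-casing for the empty case. A quick illustration (the counter name `count_1` is hypothetical; real names come from `CodegenContext.freshName`):

```scala
// Mirrors the mkString call in limitNotReachedCond above.
def cond(checks: Seq[String]): String =
  if (checks.isEmpty) "" else checks.mkString("", " && ", " &&")

cond(Nil)                // ""                 -> while ( input.hasNext()) { ... }
cond(Seq("count_1 < 5")) // "count_1 < 5 &&"   -> while (count_1 < 5 && input.hasNext()) { ... }
```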

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 6 additions & 16 deletions
@@ -45,7 +45,7 @@ case class HashAggregateExec(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport {
+  extends UnaryExecNode with BlockingOperatorWithCodegen {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -151,14 +151,6 @@ case class HashAggregateExec(
     child.asInstanceOf[CodegenSupport].inputRDDs()
   }
 
-  // The result rows come from the aggregate buffer, or a single row (no grouping keys), so this
-  // operator doesn't need to copy its result even if its child does.
-  override def needCopyResult: Boolean = false
-
-  // Aggregate operator always consumes all the input rows before outputting any result, so we
-  // don't need a stop check before aggregating.
-  override def needStopCheck: Boolean = false
-
   protected override def doProduce(ctx: CodegenContext): String = {
     if (groupingExpressions.isEmpty) {
       doProduceWithoutKeys(ctx)
@@ -705,13 +697,16 @@ case class HashAggregateExec(
 
     def outputFromRegularHashMap: String = {
       s"""
-         |while ($iterTerm.next()) {
+         |while ($limitNotReachedCond $iterTerm.next()) {
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
-         |
          |  if (shouldStop()) return;
          |}
+         |$iterTerm.close();
+         |if ($sorterTerm == null) {
+         |  $hashMapTerm.free();
+         |}
       """.stripMargin
     }
 
@@ -728,11 +723,6 @@ case class HashAggregateExec(
       // output the result
       $outputFromFastHashMap
       $outputFromRegularHashMap
-
-      $iterTerm.close();
-      if ($sorterTerm == null) {
-        $hashMapTerm.free();
-      }
     """
   }
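At the query level, the effect of the `$limitNotReachedCond` in `outputFromRegularHashMap` is that an aggregation followed by a limit stops emitting groups once the limit is satisfied. A hedged usage sketch (assumes an existing local `SparkSession` named `spark`; the grouping expression is illustrative):

```scala
import org.apache.spark.sql.functions.col

// The aggregate is blocking, so it still consumes all 10M input rows, but with this
// patch the hash-map output loop stops once 5 groups have been emitted, instead of
// emitting all 1000 groups and letting Limit discard the rest.
spark.range(0L, 10000000L)
  .groupBy(col("id") % 1000)
  .count()
  .limit(5)
  .collect()
```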

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 59 additions & 32 deletions
@@ -378,7 +378,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     val initTerm = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "initRange")
-    val number = ctx.addMutableState(CodeGenerator.JAVA_LONG, "number")
+    val nextIndex = ctx.addMutableState(CodeGenerator.JAVA_LONG, "nextIndex")
 
     val value = ctx.freshName("value")
     val ev = ExprCode.forNonNullValue(JavaCode.variable(value, LongType))
@@ -397,7 +397,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     // within a batch, while the code in the outer loop is setting batch parameters and updating
     // the metrics.
 
-    // Once number == batchEnd, it's time to progress to the next batch.
+    // Once nextIndex == batchEnd, it's time to progress to the next batch.
     val batchEnd = ctx.addMutableState(CodeGenerator.JAVA_LONG, "batchEnd")
 
     // How many values should still be generated by this range operator.
@@ -421,13 +421,13 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
        |
        | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
        | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-       |   $number = Long.MAX_VALUE;
+       |   $nextIndex = Long.MAX_VALUE;
        | } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-       |   $number = Long.MIN_VALUE;
+       |   $nextIndex = Long.MIN_VALUE;
        | } else {
-       |   $number = st.longValue();
+       |   $nextIndex = st.longValue();
        | }
-       | $batchEnd = $number;
+       | $batchEnd = $nextIndex;
        |
        | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
        |   .multiply(step).add(start);
@@ -440,7 +440,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
        | }
        |
        | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
-       |   $BigInt.valueOf($number));
+       |   $BigInt.valueOf($nextIndex));
        | $numElementsTodo = startToEnd.divide(step).longValue();
        | if ($numElementsTodo < 0) {
        |   $numElementsTodo = 0;
@@ -452,46 +452,73 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
     val localIdx = ctx.freshName("localIdx")
     val localEnd = ctx.freshName("localEnd")
-    val range = ctx.freshName("range")
     val shouldStop = if (parent.needStopCheck) {
-      s"if (shouldStop()) { $number = $value + ${step}L; return; }"
+      s"if (shouldStop()) { $nextIndex = $value + ${step}L; return; }"
     } else {
       "// shouldStop check is eliminated"
     }
+    val loopCondition = if (limitNotReachedChecks.isEmpty) {
+      "true"
+    } else {
+      limitNotReachedChecks.mkString(" && ")
+    }
+
+    // An overview of the Range processing.
+    //
+    // For each partition, the Range task needs to produce records from partition start (inclusive)
+    // to end (exclusive). For better performance, we separate the partition range into batches, and
+    // use 2 loops to produce data. The outer while loop is used to iterate batches, and the inner
+    // for loop is used to iterate records inside a batch.
+    //
+    // `nextIndex` tracks the index of the next record that is going to be consumed, initialized
+    // with partition start. `batchEnd` tracks the end index of the current batch, initialized
+    // with `nextIndex`. In the outer loop, we first check if `nextIndex == batchEnd`. If it's true,
+    // it means the current batch is fully consumed, and we will update `batchEnd` to process the
+    // next batch. If `batchEnd` reaches partition end, exit the outer loop. Finally we enter the
+    // inner loop. Note that, when we enter the inner loop, `nextIndex` must be different from
+    // `batchEnd`, otherwise we would have already exited the outer loop.
+    //
+    // The inner loop iterates from 0 to `localEnd`, which is calculated by
+    // `(batchEnd - nextIndex) / step`. Since `batchEnd` is increased by `nextBatchTodo * step` in
+    // the outer loop, and initialized with `nextIndex`, `batchEnd - nextIndex` is always divisible
+    // by `step`. `nextIndex` is increased by `step` during each iteration, and ends up being
+    // equal to `batchEnd` when the inner loop finishes.
+    //
+    // The inner loop can be interrupted, if the query has produced at least one result row, so that
+    // we don't buffer too many result rows and waste memory. It's OK to interrupt the inner loop,
+    // because `nextIndex` will be updated before interrupting.
 
     s"""
       | // initialize Range
      | if (!$initTerm) {
      |   $initTerm = true;
      |   $initRangeFuncName(partitionIndex);
      | }
      |
-     | while (true) {
-     |   long $range = $batchEnd - $number;
-     |   if ($range != 0L) {
-     |     int $localEnd = (int)($range / ${step}L);
-     |     for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
-     |       long $value = ((long)$localIdx * ${step}L) + $number;
-     |       ${consume(ctx, Seq(ev))}
-     |       $shouldStop
+     | while ($loopCondition) {
+     |   if ($nextIndex == $batchEnd) {
+     |     long $nextBatchTodo;
+     |     if ($numElementsTodo > ${batchSize}L) {
+     |       $nextBatchTodo = ${batchSize}L;
+     |       $numElementsTodo -= ${batchSize}L;
+     |     } else {
+     |       $nextBatchTodo = $numElementsTodo;
+     |       $numElementsTodo = 0;
+     |       if ($nextBatchTodo == 0) break;
      |     }
-     |     $number = $batchEnd;
+     |     $numOutput.add($nextBatchTodo);
+     |     $inputMetrics.incRecordsRead($nextBatchTodo);
+     |     $batchEnd += $nextBatchTodo * ${step}L;
     |   }
     |
-     |   $taskContext.killTaskIfInterrupted();
-     |
-     |   long $nextBatchTodo;
-     |   if ($numElementsTodo > ${batchSize}L) {
-     |     $nextBatchTodo = ${batchSize}L;
-     |     $numElementsTodo -= ${batchSize}L;
-     |   } else {
-     |     $nextBatchTodo = $numElementsTodo;
-     |     $numElementsTodo = 0;
-     |     if ($nextBatchTodo == 0) break;
+     |   int $localEnd = (int)(($batchEnd - $nextIndex) / ${step}L);
+     |   for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+     |     long $value = ((long)$localIdx * ${step}L) + $nextIndex;
+     |     ${consume(ctx, Seq(ev))}
+     |     $shouldStop
     |   }
-     |   $numOutput.add($nextBatchTodo);
-     |   $inputMetrics.incRecordsRead($nextBatchTodo);
-     |
-     |   $batchEnd += $nextBatchTodo * ${step}L;
+     |   $nextIndex = $batchEnd;
+     |   $taskContext.killTaskIfInterrupted();
     | }
    """.stripMargin
  }
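The overview comment above is easiest to follow next to a plain-Scala model of the same bookkeeping. This is a sketch only (no codegen; `limitNotReached` stands in for the generated per-batch checks, and all names are illustrative):

```scala
// Model of the generated two-loop structure: the outer loop advances batches, the
// inner loop emits rows, and the limit is checked only once per batch.
def produceRange(start: Long, end: Long, step: Long, batchSize: Long,
    limitNotReached: Long => Boolean): Seq[Long] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[Long]
  var nextIndex = start
  var batchEnd = nextIndex // batchEnd - nextIndex stays divisible by step
  var elementsTodo = (end - start) / step
  var continue = true
  while (continue && limitNotReached(out.length)) {
    if (nextIndex == batchEnd) { // current batch fully consumed: set up the next one
      val nextBatchTodo = math.min(elementsTodo, batchSize)
      elementsTodo -= nextBatchTodo
      if (nextBatchTodo == 0) continue = false
      else batchEnd += nextBatchTodo * step
    }
    if (continue) {
      val localEnd = ((batchEnd - nextIndex) / step).toInt
      var i = 0
      while (i < localEnd) { // tight inner loop: no per-row limit check
        out += i * step + nextIndex
        i += 1
      }
      nextIndex = batchEnd // the inner loop always finishes its batch exactly
    }
  }
  out.toSeq
}

// produceRange(0, 100, 1, 10, _ < 25) emits 30 rows (three full batches); the
// downstream Limit truncates them to 25, matching the per-batch check in codegen.
```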
