Minimize the necessary changes.
viirya committed Oct 26, 2016
1 parent 44c64e0 commit 86b4e42
Showing 6 changed files with 22 additions and 65 deletions.

BufferedRowIterator.java
@@ -39,15 +39,10 @@ public abstract class BufferedRowIterator {
   protected int partitionIndex = -1;
 
   public boolean hasNext() throws IOException {
-    if (!shouldStop()) {
+    if (currentRows.isEmpty()) {
       processNext();
     }
-    boolean hasNext = !currentRows.isEmpty();
-    // If no more data available, releases resource if necessary.
-    if (!hasNext) {
-      releaseResource();
-    }
-    return hasNext;
+    return !currentRows.isEmpty();
   }
 
   public InternalRow next() {
@@ -96,9 +91,4 @@ protected void incPeakExecutionMemory(long size) {
    * After it's called, if currentRow is still null, it means no more rows left.
    */
   protected abstract void processNext() throws IOException;
-
-  /**
-   * Releases resources if necessary. No-op in default.
-   */
-  protected void releaseResource() {}
 }
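The hunks above revert hasNext() to the plain buffered-iterator contract: the generated processNext() appends rows to the internal currentRows buffer, and hasNext() simply reports whether the buffer is non-empty after asking for more. A minimal plain-Scala sketch of that contract, using illustrative names (SimpleBufferedIterator, RangeRows) rather than Spark classes:

import scala.collection.mutable

// Minimal stand-in for the buffered-iterator pattern used above.
abstract class SimpleBufferedIterator[Row] {
  protected val currentRows = new mutable.Queue[Row]()

  // Subclasses append zero or more rows to currentRows on each call.
  protected def processNext(): Unit

  def hasNext: Boolean = {
    if (currentRows.isEmpty) {
      processNext()
    }
    currentRows.nonEmpty
  }

  def next(): Row = currentRows.dequeue()
}

// Example producer: emits 0 until n, one element per processNext() call.
class RangeRows(n: Int) extends SimpleBufferedIterator[Int] {
  private var i = 0
  protected def processNext(): Unit = {
    if (i < n) { currentRows.enqueue(i); i += 1 }
  }
}

Draining new RangeRows(3) with a while (it.hasNext) loop yields 0, 1, 2 and then stops, which is the behavior the simplified hasNext() relies on.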

LocalTableScanExec.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 
@@ -29,13 +28,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
  */
 case class LocalTableScanExec(
     output: Seq[Attribute],
-    rows: Seq[InternalRow]) extends LeafExecNode with CodegenSupport {
+    rows: Seq[InternalRow]) extends LeafExecNode {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
-  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(rdd)
-
   private val unsafeRows: Array[InternalRow] = {
     if (rows.isEmpty) {
       Array.empty
@@ -50,22 +47,6 @@ case class LocalTableScanExec(

   private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows, numParallelism)
 
-  protected override def doProduce(ctx: CodegenContext): String = {
-    val numOutput = metricTerm(ctx, "numOutputRows")
-    val input = ctx.freshName("input")
-    // Right now, LocalTableScanExec is only used when there is one upstream.
-    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-    val row = ctx.freshName("row")
-    s"""
-       | while ($input.hasNext()) {
-       |   InternalRow $row = (InternalRow) $input.next();
-       |   $numOutput.add(1);
-       |   ${consume(ctx, null, row).trim}
-       |   if (shouldStop()) return;
-       | }
-     """.stripMargin
-  }
-
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     rdd.map { r =>
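With CodegenSupport dropped, LocalTableScanExec no longer emits the scan loop above and runs only through doExecute(). As a rough, hedged guide to what the removed doProduce() generated at runtime, here is a plain-Scala equivalent of that loop; consume, shouldStop, and incNumOutputRows stand in for generated code and Spark metrics, and the row type is a plain type parameter rather than InternalRow:

// Drain the single upstream iterator, bump the output-row metric, hand each row
// to the downstream consumer, and bail out as soon as the consumer asks to stop.
def scanLoop[Row](input: Iterator[Row])(
    consume: Row => Unit,
    shouldStop: () => Boolean,
    incNumOutputRows: () => Unit): Unit = {
  while (input.hasNext) {
    val row = input.next()
    incNumOutputRows()
    consume(row)
    if (shouldStop()) return
  }
}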

WholeStageCodegenExec.scala
@@ -298,12 +298,6 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-  override def executeCollect(): Array[InternalRow] = child match {
-    // A physical Limit operator has optimized executeCollect which scans increasingly
-    // the RDD to get the limited items, without fully materializing the RDD.
-    case g: GlobalLimitExec => g.executeCollect()
-    case _ => super.executeCollect()
-  }
 
   override lazy val metrics = Map(
     "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
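The removed comment refers to an executeCollect that scans the RDD "increasingly" to fetch only the limited rows. A hedged sketch of that idea in plain Scala, working over an in-memory Seq of "partitions" rather than an RDD, so this is not Spark's actual executeTake:

import scala.collection.mutable

// Take the first n elements by visiting partitions one at a time, pulling only
// what is still missing instead of materializing every partition up front.
def takeIncrementally[T](partitions: Seq[Seq[T]], n: Int): Seq[T] = {
  val buf = mutable.ArrayBuffer.empty[T]
  val it = partitions.iterator
  while (buf.size < n && it.hasNext) {
    buf ++= it.next().take(n - buf.size)
  }
  buf.toSeq
}

// takeIncrementally(Seq(Seq(1, 2), Seq(3, 4), Seq(5)), 3) == Seq(1, 2, 3),
// and the last partition is never touched.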

HashAggregateExec.scala
@@ -663,15 +663,6 @@ case class HashAggregateExec(
""".stripMargin
}

ctx.addNewFunction("releaseResource", s"""
@Override
protected void releaseResource() {
$iterTerm.close();
if ($sorterTerm == null) {
$hashMapTerm.free();
}
}
""")

val aggTime = metricTerm(ctx, "aggTime")
val beforeAgg = ctx.freshName("beforeAgg")
Expand Down
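The generated releaseResource() removed here closed the aggregation iterator and, when no sort-based fallback was in use, freed the hash map once no rows remained. A hedged plain-Scala sketch of that behavior is an iterator wrapper that runs a cleanup action exactly once when the underlying iterator is exhausted (CleanupIterator is an illustrative name, not the mechanism Spark uses after this change):

class CleanupIterator[T](underlying: Iterator[T], cleanup: () => Unit) extends Iterator[T] {
  private var cleaned = false

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    // Once the source is drained, release resources exactly once.
    if (!more && !cleaned) {
      cleanup()
      cleaned = true
    }
    more
  }

  override def next(): T = underlying.next()
}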

sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala (18 changes: 5 additions & 13 deletions)
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.util.Utils
 
 /**
@@ -33,17 +33,8 @@ import org.apache.spark.util.Utils
 trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
   val limit: Int
   override def output: Seq[Attribute] = child.output
-  override def executeCollect(): Array[InternalRow] = {
-    child match {
-      // Shuffling is inserted under whole stage codegen.
-      case InputAdapter(ShuffleExchange(_, WholeStageCodegenExec(l: LocalLimitExec), _)) =>
-        l.executeCollect()
-      // Shuffling is inserted without whole stage codegen.
-      case ShuffleExchange(_, l: LocalLimitExec, _) => l.executeCollect()
-      // No shuffling happened.
-      case _ => child.executeTake(limit)
-    }
-  }
+  override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
+  override def executeTake(n: Int): Array[InternalRow] = child.executeTake(limit)
 
   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
     iter.take(limit)
@@ -72,8 +63,9 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
s"""
| if ($countTerm < $limit) {
| $countTerm += 1;
| if ($countTerm == $limit) $stopEarly = true;
| ${consume(ctx, input)}
| } else {
| $stopEarly = true;
| }
""".stripMargin
}
Expand Down
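With the else branch above, the generated loop only sets stopEarly after it has already pulled one row beyond the limit, which is exactly the "limit(n) pulls n + 1 items" behavior that the test change below works around. A plain-Scala rendering of that control flow (limitedConsume and its parameters are illustrative names, not generated identifiers):

def limitedConsume[T](input: Iterator[T], limit: Int)(consume: T => Unit): Unit = {
  var count = 0
  var stopEarly = false
  // The !stopEarly check plays the role of shouldStop() in the generated loop.
  while (!stopEarly && input.hasNext) {
    val row = input.next()            // for limit n, the (n + 1)-th pull happens here
    if (count < limit) {
      count += 1
      consume(row)
    } else {
      stopEarly = true                // only now does the loop stop
    }
  }
}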

SQLMetricsSuite.scala
@@ -95,11 +95,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     assert(metrics1.contains("numOutputRows"))
     assert(metrics1("numOutputRows").value === 3)
 
-    val df2 = spark.createDataset(Seq(1, 2, 3)).limit(2)
-    df2.collect()
-    val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics
-    assert(metrics2.contains("numOutputRows"))
-    assert(metrics2("numOutputRows").value === 2)
+    // Due to the iteration processing of WholeStage Codegen,
+    // limit(n) operator will pull n + 1 items from previous operator. So we tune off
+    // WholeStage Codegen here to test the metrics.
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+      // A physical GlobalLimitExec and a LocalLimitExec constitute a logical Limit operator.
+      // When we ask 2 items, each partition will be asked 2 items from the LocalLimitExec.
+      val df2 = spark.createDataset(Seq(1, 2, 3)).limit(1)
+      df2.collect()
+      // The number of partitions of the LocalTableScanExec.
+      val parallelism = math.min(sparkContext.defaultParallelism, 3)
+      val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics
+      assert(metrics2.contains("numOutputRows"))
+      assert(metrics2("numOutputRows").value === parallelism * 1)
+    }
   }
 
   test("Filter metrics") {
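To make the expected value in the new assertion concrete, here is the arithmetic it encodes, assuming local mode with sparkContext.defaultParallelism == 4 (the parallelism value is an assumption for illustration only):

val defaultParallelism = 4
val parallelism = math.min(defaultParallelism, 3)  // LocalTableScanExec gets 3 partitions
// With codegen disabled, each partition's LocalLimitExec takes exactly 1 row from
// the scan for limit(1), so the scan's numOutputRows metric should equal 3.
assert(parallelism * 1 == 3)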
