Fix test.
viirya committed Oct 24, 2016
1 parent 3d24f79 commit 76a3eaf
Showing 5 changed files with 56 additions and 9 deletions.
@@ -42,7 +42,12 @@ public boolean hasNext() throws IOException {
     if (!shouldStop()) {
       processNext();
     }
-    return !currentRows.isEmpty();
+    boolean hasNext = !currentRows.isEmpty();
+    // If no more data available, releases resource if necessary.
+    if (!hasNext) {
+      releaseResource();
+    }
+    return hasNext;
   }
 
   public InternalRow next() {
@@ -91,4 +96,9 @@ protected void incPeakExecutionMemory(long size) {
    * After it's called, if currentRow is still null, it means no more rows left.
    */
   protected abstract void processNext() throws IOException;
+
+  /**
+   * Releases resources if necessary. No-op in default.
+   */
+  protected void releaseResource() {}
 }
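
Note: the hasNext() change above makes the iterator release its resources as soon as it knows the input is exhausted, with releaseResource() added as an overridable no-op hook. A minimal Scala sketch of the same pattern follows; BufferedIterator, fetchNext and the Option-based buffering are illustrative choices, not code from this commit.

// Minimal sketch of the eager-release pattern (illustrative names only).
abstract class BufferedIterator[T] {
  private var current: Option[T] = None
  private var exhausted = false

  protected def fetchNext(): Option[T]        // None once the input is drained
  protected def releaseResource(): Unit = {}  // no-op by default, as in the diff

  def hasNext: Boolean = {
    if (current.isEmpty && !exhausted) {
      current = fetchNext()
      if (current.isEmpty) {
        exhausted = true
        // Release as soon as exhaustion is detected, instead of waiting for
        // the task to finish or for the object to be garbage collected.
        releaseResource()
      }
    }
    current.nonEmpty
  }

  def next(): T = {
    if (!hasNext) throw new NoSuchElementException("end of input")
    val row = current.get
    current = None
    row
  }
}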
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 
@@ -28,11 +29,13 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
  */
 case class LocalTableScanExec(
     output: Seq[Attribute],
-    rows: Seq[InternalRow]) extends LeafExecNode {
+    rows: Seq[InternalRow]) extends LeafExecNode with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
+  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(rdd)
+
   private val unsafeRows: Array[InternalRow] = {
     if (rows.isEmpty) {
       Array.empty
@@ -47,6 +50,22 @@ case class LocalTableScanExec(
 
   private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows, numParallelism)
 
+  protected override def doProduce(ctx: CodegenContext): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val input = ctx.freshName("input")
+    // Right now, LocalTableScanExec is only used when there is one upstream.
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    val row = ctx.freshName("row")
+    s"""
+       | while ($input.hasNext()) {
+       |   InternalRow $row = (InternalRow) $input.next();
+       |   $numOutput.add(1);
+       |   ${consume(ctx, null, row).trim}
+       |   if (shouldStop()) return;
+       | }
+     """.stripMargin
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     rdd.map { r =>
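
Implementing CodegenSupport lets the local-table scan feed rows into a whole-stage-codegen loop: the generated code pulls from inputs[0], bumps numOutputRows, and hands each row to consume(). A hedged way to look at the effect from a spark-shell follows; it assumes a SparkSession named spark, and whether the scan actually ends up inside a WholeStageCodegenExec subtree depends on the surrounding plan and Spark version. WholeStageCodegenExec and debugCodegen() are existing Spark 2.x APIs.

import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.debug._  // adds debugCodegen()
import spark.implicits._

val df = Seq((1, "a"), (2, "b")).toDF("id", "name").filter($"id" > 1)

// Count whole-stage-codegen subtrees in the physical plan.
val codegenNodes = df.queryExecution.executedPlan.collect {
  case w: WholeStageCodegenExec => w
}
println(s"whole-stage codegen subtrees: ${codegenNodes.size}")

// Dump the generated Java source for inspection.
df.debugCodegen()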
@@ -661,6 +661,15 @@ case class HashAggregateExec(
       """.stripMargin
     }
 
+    ctx.addNewFunction("releaseResource", s"""
+      @Override
+      protected void releaseResource() {
+        $iterTerm.close();
+        if ($sorterTerm == null) {
+          $hashMapTerm.free();
+        }
+      }
+    """)
 
     val aggTime = metricTerm(ctx, "aggTime")
     val beforeAgg = ctx.freshName("beforeAgg")
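
In HashAggregateExec, the generated releaseResource() closes the aggregation iterator and frees the hash map only when no sorter was created ($sorterTerm == null), presumably because a spill hands the map's memory over to the sorter. A toy Scala sketch of that rule; FakeHashMap, FakeSorter and FakeAggIter are made-up stand-ins, not Spark classes.

// Illustrative only: close the iterator unconditionally, free the map only
// when no sorter exists (assumption: a sorter owns the map's memory after a spill).
class FakeHashMap { def free(): Unit = println("hash map freed") }
class FakeSorter
class FakeAggIter(map: FakeHashMap, sorter: Option[FakeSorter]) {
  def close(): Unit = println("iterator closed")
  def releaseResource(): Unit = {
    close()
    if (sorter.isEmpty) map.free()
  }
}

new FakeAggIter(new FakeHashMap, sorter = None).releaseResource()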
@@ -37,9 +37,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
-  override def executeCollect(): Array[InternalRow] = child match {
-    case e: Exchange => e.child.executeTake(limit)
-    case _ => child.executeTake(limit)
+  override def executeCollect(): Array[InternalRow] = {
+    child.collect {
+      case l: LocalLimitExec => l
+    }.head.child.executeTake(limit)
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
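
CollectLimitExec.executeCollect() previously peeled off at most one Exchange before calling executeTake; it now searches the whole child plan for the LocalLimitExec node and takes `limit` rows from below it. To see where the local and global limit operators sit for a given query, the physical plan can simply be printed (assumes a SparkSession named spark):

// Look for LocalLimit / GlobalLimit / CollectLimit operators in the output.
spark.range(1, 100, 1, numPartitions = 10).limit(1).explain()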
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (12 additions, 4 deletions)
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.LocalLimitExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2684,11 +2685,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+    val df = spark.range(1, 100, 1, numPartitions = 10).limit(1)
+    val localLimit = df.queryExecution.executedPlan.collect {
+      case l: LocalLimitExec => l
+    }
+    assert(localLimit.nonEmpty)
     val numRecordsRead = spark.sparkContext.longAccumulator
-    spark.range(1, 100, 1, numPartitions = 10).map { x =>
-      numRecordsRead.add(1)
-      x
-    }.limit(1).queryExecution.toRdd.count()
+    localLimit.head.execute().mapPartitionsInternal { iter =>
+      iter.map { x =>
+        numRecordsRead.add(1)
+        x
+      }
+    }.count
     assert(numRecordsRead.value === 10)
   }
 
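The rewritten test runs the LocalLimitExec operator directly and counts the rows it emits: with 10 partitions and limit(1), each partition should contribute at most one row, so the accumulator ends at 10. A rough standalone equivalent follows, using the public mapPartitions instead of the internal mapPartitionsInternal; it assumes a SparkSession named spark, and the plan containing a LocalLimitExec depends on the Spark version under test.

import org.apache.spark.sql.execution.LocalLimitExec

val numRecordsRead = spark.sparkContext.longAccumulator("rowsPastLocalLimit")
val plan = spark.range(1, 100, 1, numPartitions = 10).limit(1).queryExecution.executedPlan

// Pull the per-partition limit out of the plan and execute it directly.
val localLimit = plan.collect { case l: LocalLimitExec => l }.head
localLimit.execute().mapPartitions { iter =>
  iter.map { row => numRecordsRead.add(1); row }
}.count()

println(numRecordsRead.value)  // expected: 10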
