From 76a3eaf16ca68f1c0329cead866f8e2941285507 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Mon, 24 Oct 2016 03:38:18 +0000
Subject: [PATCH] Fix test.

---
 .../sql/execution/BufferedRowIterator.java   | 12 +++++++++++-
 .../sql/execution/LocalTableScanExec.scala   | 21 ++++++++++++++++++++-
 .../aggregate/HashAggregateExec.scala        |  9 +++++++++
 .../apache/spark/sql/execution/limit.scala   |  7 ++++---
 .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++++++----
 5 files changed, 56 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 2c5a59f9f7401..350582ca4d371 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -42,7 +42,12 @@ public boolean hasNext() throws IOException {
     if (!shouldStop()) {
       processNext();
     }
-    return !currentRows.isEmpty();
+    boolean hasNext = !currentRows.isEmpty();
+    // If no more data is available, release resources if necessary.
+    if (!hasNext) {
+      releaseResource();
+    }
+    return hasNext;
   }
 
   public InternalRow next() {
@@ -91,4 +96,9 @@ protected void incPeakExecutionMemory(long size) {
    * After it's called, if currentRow is still null, it means no more rows left.
    */
   protected abstract void processNext() throws IOException;
+
+  /**
+   * Releases resources if necessary. No-op by default.
+   */
+  protected void releaseResource() {}
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index e366b9af35c62..51d8891f933ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
 
@@ -28,11 +29,13 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
  */
 case class LocalTableScanExec(
     output: Seq[Attribute],
-    rows: Seq[InternalRow]) extends LeafExecNode {
+    rows: Seq[InternalRow]) extends LeafExecNode with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
+  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(rdd)
+
   private val unsafeRows: Array[InternalRow] = {
     if (rows.isEmpty) {
       Array.empty
@@ -47,6 +50,22 @@ case class LocalTableScanExec(
 
   private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows, numParallelism)
 
+  protected override def doProduce(ctx: CodegenContext): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val input = ctx.freshName("input")
+    // Right now, LocalTableScanExec is only used when there is one upstream.
+    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+    val row = ctx.freshName("row")
+    s"""
+       | while ($input.hasNext()) {
+       |   InternalRow $row = (InternalRow) $input.next();
+       |   $numOutput.add(1);
+       |   ${consume(ctx, null, row).trim}
+       |   if (shouldStop()) return;
+       | }
+     """.stripMargin
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     rdd.map { r =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 06199ef3e8243..6d96ed84c99e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -661,6 +661,15 @@ case class HashAggregateExec(
      """.stripMargin
     }
 
+    ctx.addNewFunction("releaseResource", s"""
+      @Override
+      protected void releaseResource() {
+        $iterTerm.close();
+        if ($sorterTerm == null) {
+          $hashMapTerm.free();
+        }
+      }
+      """)
     val aggTime = metricTerm(ctx, "aggTime")
     val beforeAgg = ctx.freshName("beforeAgg")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index ebd5790bbd14f..e5ae6a23f62a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -37,9 +37,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
-  override def executeCollect(): Array[InternalRow] = child match {
-    case e: Exchange => e.child.executeTake(limit)
-    case _ => child.executeTake(limit)
+  override def executeCollect(): Array[InternalRow] = {
+    child.collect {
+      case l: LocalLimitExec => l
+    }.head.child.executeTake(limit)
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 60978efddd7f8..f390578b9c96a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.execution.LocalLimitExec
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2684,11 +2685,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+    val df = spark.range(1, 100, 1, numPartitions = 10).limit(1)
+    val localLimit = df.queryExecution.executedPlan.collect {
+      case l: LocalLimitExec => l
+    }
+    assert(localLimit.nonEmpty)
     val numRecordsRead = spark.sparkContext.longAccumulator
-    spark.range(1, 100, 1, numPartitions = 10).map { x =>
-      numRecordsRead.add(1)
-      x
-    }.limit(1).queryExecution.toRdd.count()
+    localLimit.head.execute().mapPartitionsInternal { iter =>
+      iter.map { x =>
+        numRecordsRead.add(1)
+        x
+      }
+    }.count
     assert(numRecordsRead.value === 10)
   }
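
The core of the change is a release-on-exhaustion contract: BufferedRowIterator.hasNext() keeps calling the generated processNext() until no buffered rows remain, and at that point it invokes the new releaseResource() hook, which stateful operators such as the generated hash aggregate override to close their iterator and free the hash map. Below is a minimal standalone sketch of that contract; the names SketchRowIterator and CountingIterator are hypothetical illustrations, not Spark classes.

// Minimal standalone sketch (hypothetical, not Spark code) of the release-on-exhaustion
// contract that this patch adds to BufferedRowIterator.
import scala.collection.mutable

abstract class SketchRowIterator[T] {
  // Rows buffered by processNext() and handed out by next().
  protected val currentRows = new mutable.Queue[T]()

  // Buffers zero or more rows into currentRows; leaves it empty once the input is exhausted.
  protected def processNext(): Unit

  // Hook for releasing operator-owned resources; no-op by default, mirroring the patch.
  protected def releaseResource(): Unit = {}

  def hasNext: Boolean = {
    if (currentRows.isEmpty) {
      processNext()
    }
    val more = currentRows.nonEmpty
    if (!more) {
      // Release as soon as the iterator is known to be drained, not when it is GC'd.
      releaseResource()
    }
    more
  }

  def next(): T = currentRows.dequeue()
}

// Example subclass that owns a resource which must be freed once all rows are emitted,
// playing the role of the generated hash aggregate's map and sorter.
class CountingIterator(n: Int) extends SketchRowIterator[Int] {
  private var produced = 0
  private var scratch: Array[Int] = new Array[Int](n) // stand-in for a hash map

  override protected def processNext(): Unit = {
    if (produced < n) {
      currentRows.enqueue(scratch(produced))
      produced += 1
    }
  }

  override protected def releaseResource(): Unit = {
    scratch = null // analogous to $iterTerm.close() / $hashMapTerm.free() in the generated code
  }
}

The generated HashAggregateExec code in the patch follows the same shape: releaseResource() closes $iterTerm and frees $hashMapTerm once hasNext() reports that no rows are left.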
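
The reworked SPARK-17515 test drives the LocalLimitExec node directly and expects exactly one row to be read from each of the ten partitions. A rough standalone analogue of that per-partition-limit property, written against the public RDD API rather than Spark SQL internals, is sketched below; the object name and the local SparkSession setup are assumptions made for the sake of a runnable example.

// Hypothetical standalone check of "read at most one row per partition":
// with 10 partitions and a per-partition limit of 1, exactly 10 source rows are inspected.
import org.apache.spark.sql.SparkSession

object PerPartitionLimitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("per-partition-limit-sketch")
      .getOrCreate()
    val numRecordsRead = spark.sparkContext.longAccumulator

    // The map is lazy on the partition iterator, so only the rows actually pulled by the
    // per-partition take(1) are counted by the accumulator.
    val limited = spark.range(1, 100, 1, numPartitions = 10).rdd.mapPartitions { iter =>
      iter.map { x =>
        numRecordsRead.add(1)
        x
      }.take(1)
    }

    limited.count()
    assert(numRecordsRead.sum == 10) // one row read per partition
    spark.stop()
  }
}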