From 506590753dbe59c71bf576751d3c2a50a70d9099 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 12 Sep 2016 19:09:16 -0700
Subject: [PATCH 1/2] Add regression test.

---
 .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eac266cba55b..a2164f9ae3d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2661,4 +2661,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         data.selectExpr("`part.col1`", "`col.1`"))
     }
   }
+
+  test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+    val numRecordsRead = spark.sparkContext.longAccumulator
+    spark.range(1, 100, 1, numPartitions = 10).map { x =>
+      numRecordsRead.add(1)
+      x
+    }.limit(1).queryExecution.toRdd.count()
+    assert(numRecordsRead.value === 10)
+  }
 }

From ac01f87f72e7df9c8c40a00822b98ebccb0ff1f6 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 12 Sep 2016 19:09:24 -0700
Subject: [PATCH 2/2] Fix by adding missing limit.

---
 .../src/main/scala/org/apache/spark/sql/execution/limit.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 781c01609542..345fea6f567a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -39,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
   protected override def doExecute(): RDD[InternalRow] = {
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
     val shuffled = new ShuffledRowRDD(
       ShuffleExchange.prepareShuffleDependency(
-        child.execute(), child.output, SinglePartition, serializer))
+        locallyLimited, child.output, SinglePartition, serializer))
     shuffled.mapPartitionsInternal(_.take(limit))
   }
 }
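
Note on the fix (not part of the patches themselves): the change works as a two-stage limit. Taking at most `limit` rows from each input partition before the SinglePartition shuffle bounds the shuffled data at limit * numPartitions rows instead of the child's entire output, and the second take() after the shuffle then enforces the global limit; this is what the regression test's accumulator observes (1 row read per each of the 10 partitions, hence 10). A minimal runnable sketch of the same pattern on plain RDDs follows; the object and variable names here are illustrative, not taken from Spark's source, and the local SparkSession setup is an assumption for demonstration.

import org.apache.spark.sql.SparkSession

object PerPartitionLimitSketch {
  def main(args: Array[String]): Unit = {
    // Assumed local setup, for illustration only.
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("per-partition-limit-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    val limit = 1
    val data = sc.parallelize(1 to 99, numSlices = 10)

    // Stage 1: each of the 10 partitions emits at most `limit` rows, so the
    // single-partition step below sees at most 10 rows rather than all 99.
    val locallyLimited = data.mapPartitions(_.take(limit))

    // Stage 2: collapse to one partition (a stand-in for the SinglePartition
    // shuffle in CollectLimitExec) and apply the final, global limit.
    val globallyLimited = locallyLimited
      .coalesce(1, shuffle = true)
      .mapPartitions(_.take(limit))

    assert(globallyLimited.count() == limit)
    spark.stop()
  }
}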