From 3ea0daf9114ec23c81d84f44ce94ee37aca5e55e Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 9 Sep 2016 11:59:29 -0700
Subject: [PATCH 1/3] fix python udf in TakeOrderedAndProjectExec

---
 python/pyspark/sql/tests.py                          |  8 ++++++++
 .../apache/spark/sql/execution/SparkStrategies.scala |  8 ++++----
 .../scala/org/apache/spark/sql/execution/limit.scala | 12 ++++++------
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fd8e9cec3e0b..769e4540720e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -376,6 +376,14 @@ def test_udf_in_generate(self):
         row = df.select(explode(f(*df))).groupBy().sum().first()
         self.assertEqual(row[0], 10)
 
+    def test_udf_with_order_by_and_limit(self):
+        from pyspark.sql.functions import udf
+        my_copy = udf(lambda x: x, IntegerType())
+        df = self.spark.range(10).orderBy("id")
+        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
+        res.explain(True)
+        self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c389593b4f76..3441ccf53b45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -66,22 +66,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.ReturnAnswer(rootPlan) => rootPlan match {
         case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-          execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
+          execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
         case logical.Limit(
             IntegerLiteral(limit),
             logical.Project(projectList, logical.Sort(order, true, child))) =>
           execution.TakeOrderedAndProjectExec(
-            limit, order, Some(projectList), planLater(child)) :: Nil
+            limit, order, projectList, planLater(child)) :: Nil
         case logical.Limit(IntegerLiteral(limit), child) =>
           execution.CollectLimitExec(limit, planLater(child)) :: Nil
         case other => planLater(other) :: Nil
       }
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
+        execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
       case logical.Limit(
           IntegerLiteral(limit),
           logical.Project(projectList, logical.Sort(order, true, child))) =>
         execution.TakeOrderedAndProjectExec(
-          limit, order, Some(projectList), planLater(child)) :: Nil
+          limit, order, projectList, planLater(child)) :: Nil
       case _ => Nil
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 781c01609542..562dfafb1142 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -114,11 +114,11 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
 case class TakeOrderedAndProjectExec(
     limit: Int,
     sortOrder: Seq[SortOrder],
-    projectList: Option[Seq[NamedExpression]],
+    projectList: Seq[NamedExpression],
     child: SparkPlan) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = {
-    projectList.map(_.map(_.toAttribute)).getOrElse(child.output)
+    projectList.map(_.toAttribute)
   }
 
   override def outputPartitioning: Partitioning = SinglePartition
@@ -126,8 +126,8 @@ case class TakeOrderedAndProjectExec(
   override def executeCollect(): Array[InternalRow] = {
     val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
     val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    if (projectList.isDefined) {
-      val proj = UnsafeProjection.create(projectList.get, child.output)
+    if (AttributeSet(projectList) != child.outputSet) {
+      val proj = UnsafeProjection.create(projectList, child.output)
       data.map(r => proj(r).copy())
     } else {
       data
@@ -148,8 +148,8 @@ case class TakeOrderedAndProjectExec(
         localTopK, child.output, SinglePartition, serializer))
     shuffled.mapPartitions { iter =>
       val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
-      if (projectList.isDefined) {
-        val proj = UnsafeProjection.create(projectList.get, child.output)
+      if (AttributeSet(projectList) != child.outputSet) {
+        val proj = UnsafeProjection.create(projectList, child.output)
         topK.map(r => proj(r))
       } else {
         topK
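
Background on the change above: projectList becomes a mandatory
Seq[NamedExpression] instead of Option[Seq[NamedExpression]], and the planner
now passes child.output where it used to pass None. A likely reason the Option
form broke Python UDFs: Spark's plan-wide expression transforms (which rules
such as ExtractPythonUDFs rely on) recurse into constructor arguments typed
Expression, Option[Expression], and Seq[Expression], but not
Option[Seq[Expression]], so a UDF buried inside the optional project list was
never rewritten. The sketch below models the new semantics in plain Scala
(Row, takeOrderedAndProject, and the column names are hypothetical stand-ins,
not Spark's classes): the projection is always present, and is skipped at run
time only when it equals the input columns.

    object TakeOrderedSketch {
      type Row = Map[String, Int]

      // Keep the `limit` smallest rows under `ord`, then project the named
      // columns. An identity projection (project == inputColumns) is a no-op,
      // mirroring the planner now passing child.output instead of None.
      def takeOrderedAndProject(
          rows: Seq[Row],
          ord: Ordering[Row],
          limit: Int,
          inputColumns: Seq[String],
          project: Seq[String]): Seq[Row] = {
        val topK = rows.sorted(ord).take(limit)
        if (project != inputColumns) {
          topK.map(r => project.map(c => c -> r(c)).toMap)
        } else {
          topK
        }
      }

      def main(args: Array[String]): Unit = {
        val rows = Seq(
          Map("id" -> 3, "copy" -> 3),
          Map("id" -> 1, "copy" -> 1),
          Map("id" -> 2, "copy" -> 2))
        val byId = Ordering.by[Row, Int](_("id"))
        // A real projection: keep only "id".
        println(takeOrderedAndProject(rows, byId, 1, Seq("id", "copy"), Seq("id")))
        // The identity projection: previously None, now the child's own output.
        println(takeOrderedAndProject(rows, byId, 1, Seq("id", "copy"), Seq("id", "copy")))
      }
    }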
From 263c147edaf63981b29e6d3573f5658fc3c369ba Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 9 Sep 2016 13:15:44 -0700
Subject: [PATCH 2/3] fix tests

---
 .../spark/sql/execution/TakeOrderedAndProjectSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
index 3217e34bd8ad..7e317a4d8026 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -59,7 +59,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
     checkThatPlansAgree(
       generateRandomInputData(),
       input =>
-        noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, None, input)),
+        noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
       input =>
         GlobalLimitExec(limit,
           LocalLimitExec(limit,
@@ -74,7 +74,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
       generateRandomInputData(),
       input =>
         noOpFilter(
-          TakeOrderedAndProjectExec(limit, sortOrder, Some(Seq(input.output.last)), input)),
+          TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
       input =>
         GlobalLimitExec(limit,
           LocalLimitExec(limit,
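
The suite changes above are the mechanical consequence of the new signature:
call sites stop wrapping the projection in Option (None becomes input.output,
Some(Seq(...)) becomes Seq(...)). A minimal sketch of the two constructor
shapes, in plain Scala (Expr, NodeBefore, and NodeAfter are hypothetical
names, not Spark's API):

    case class Expr(name: String)

    // Before: an optional projection, so tests passed None or Some(Seq(...)).
    case class NodeBefore(project: Option[Seq[Expr]], childOutput: Seq[Expr]) {
      def output: Seq[Expr] = project.getOrElse(childOutput)
    }

    // After: the projection is always present; the identity case passes the
    // child's own output, as the updated tests do with input.output.
    case class NodeAfter(project: Seq[Expr], childOutput: Seq[Expr]) {
      def output: Seq[Expr] = project
    }

    object CallSites {
      def main(args: Array[String]): Unit = {
        val cols = Seq(Expr("a"), Expr("b"))
        // None            -> input.output   (same output either way)
        println(NodeBefore(None, cols).output == NodeAfter(cols, cols).output)
        // Some(Seq(last)) -> Seq(last)      (same output either way)
        println(NodeBefore(Some(Seq(cols.last)), cols).output ==
          NodeAfter(Seq(cols.last), cols).output)
      }
    }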
From 1e319d8f4ef1adf69b4fffa928bc1ac0c0f21805 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 12 Sep 2016 12:09:57 -0700
Subject: [PATCH 3/3] use Seq to compare

---
 .../src/main/scala/org/apache/spark/sql/execution/limit.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 562dfafb1142..01fbe5b7c2c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -126,7 +126,7 @@ case class TakeOrderedAndProjectExec(
   override def executeCollect(): Array[InternalRow] = {
     val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
     val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    if (AttributeSet(projectList) != child.outputSet) {
+    if (projectList != child.output) {
       val proj = UnsafeProjection.create(projectList, child.output)
       data.map(r => proj(r).copy())
     } else {
@@ -148,7 +148,7 @@ case class TakeOrderedAndProjectExec(
         localTopK, child.output, SinglePartition, serializer))
     shuffled.mapPartitions { iter =>
       val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
-      if (AttributeSet(projectList) != child.outputSet) {
+      if (projectList != child.output) {
        val proj = UnsafeProjection.create(projectList, child.output)
        topK.map(r => proj(r))
      } else {
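
Why "use Seq to compare": AttributeSet equality is order-insensitive, so a
projection that merely reorders the child's columns would compare equal to
child.outputSet and be skipped, emitting columns in the wrong order. Comparing
projectList against child.output as sequences keeps the check order-sensitive.
A plain-Scala illustration of the difference (strings stand in for attributes;
this is not Spark's AttributeSet):

    object SeqVsSet {
      def main(args: Array[String]): Unit = {
        val childOutput = Seq("a", "b")
        val reordering = Seq("b", "a")
        // Set comparison: equal, so the projection would be wrongly skipped.
        println(reordering.toSet == childOutput.toSet) // true
        // Seq comparison: not equal, so the projection is correctly applied.
        println(reordering == childOutput)             // false
      }
    }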