From dd8fbd98df96e2fd293906adc086651dbeb72074 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 19 May 2015 16:20:19 +0800
Subject: [PATCH 1/2] correct semanticEquals

---
 .../sql/catalyst/expressions/Expression.scala    | 13 +++++++++----
 .../spark/sql/catalyst/planning/patterns.scala   | 13 ++++++-------
 .../spark/sql/execution/GeneratedAggregate.scala | 15 ++++++++-------
 3 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8c1e4d74f9df..0b9f621fed7c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -136,12 +136,17 @@ abstract class Expression extends TreeNode[Expression] {
    * cosmetically (i.e. capitalization of names in attributes may be different).
    */
   def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+    def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
+      elements1.length == elements2.length && elements1.zip(elements2).forall {
+        case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+        case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
+        case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
+        case (i1, i2) => i1 == i2
+      }
+    }
     val elements1 = this.productIterator.toSeq
     val elements2 = other.asInstanceOf[Product].productIterator.toSeq
-    elements1.length == elements2.length && elements1.zip(elements2).forall {
-      case (e1: Expression, e2: Expression) => e1 semanticEquals e2
-      case (i1, i2) => i1 == i2
-    }
+    checkSemantic(elements1, elements2)
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 1dd75a884630..b643efba54cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -143,11 +143,11 @@ object PartialAggregation {
       // We need to pass all grouping expressions though so the grouping can happen a second
       // time. However some of them might be unnamed so we alias them allowing them to be
       // referenced in the second aggregation.
-      val namedGroupingExpressions: Map[Expression, NamedExpression] =
+      val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
         groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
           case n: NamedExpression => (n, n)
           case other => (other, Alias(other, "PartialGroup")())
-        }.toMap
+        }
 
       // Replace aggregations with a new expression that computes the result from the already
       // computed partial evaluations and grouping values.
@@ -161,16 +161,15 @@ object PartialAggregation {
           // (Should we just turn `GetField` into a `NamedExpression`?)
           val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
           namedGroupingExpressions
-            .find { case (k, v) => k semanticEquals trimmed }
+            .find { case (expr, _) => expr semanticEquals trimmed }
             .map(_._2.toAttribute)
             .getOrElse(e)
       }).asInstanceOf[Seq[NamedExpression]]
 
-      val partialComputation =
-        (namedGroupingExpressions.values ++
-          partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
+      val partialComputation = namedGroupingExpressions.map(_._2) ++
+        partialEvaluations.values.flatMap(_.partialEvaluations)
 
-      val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
+      val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
 
       Some(
         (namedGroupingAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index af3791734d0c..b55a65266008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -214,18 +214,19 @@ case class GeneratedAggregate(
     }.toMap
 
     val namedGroups = groupingExpressions.zipWithIndex.map {
-      case (ne: NamedExpression, _) => (ne, ne)
-      case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
+      case (ne: NamedExpression, _) => (ne, ne.toAttribute)
+      case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
     }
 
-    val groupMap: Map[Expression, Attribute] =
-      namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
-
     // The set of expressions that produce the final output given the aggregation buffer and the
     // grouping expressions.
     val resultExpressions = aggregateExpressions.map(_.transform {
       case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
-      case e: Expression if groupMap.contains(e) => groupMap(e)
+      case e: Expression =>
+        namedGroups
+          .find { case (expr, _) => expr semanticEquals e }
+          .map(_._2)
+          .getOrElse(e)
     })
 
     val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
@@ -265,7 +266,7 @@ case class GeneratedAggregate(
     val resultProjectionBuilder =
       newMutableProjection(
         resultExpressions,
-        (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
+        namedGroups.map(_._2) ++ computationSchema)
     log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
 
     val joinedRow = new JoinedRow3

From 4daef887931bd662901c1c9af24a8cb66286ce1b Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 10 Jun 2015 22:28:20 +0800
Subject: [PATCH 2/2] address comments

---
 .../org/apache/spark/sql/catalyst/planning/patterns.scala | 7 +++----
 .../apache/spark/sql/execution/GeneratedAggregate.scala   | 7 +++----
 .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala   | 2 +-
 3 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index b643efba54cf..3b6f8bfd9ff9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -160,10 +160,9 @@ object PartialAggregation {
           // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
           // (Should we just turn `GetField` into a `NamedExpression`?)
           val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
-          namedGroupingExpressions
-            .find { case (expr, _) => expr semanticEquals trimmed }
-            .map(_._2.toAttribute)
-            .getOrElse(e)
+          namedGroupingExpressions.collectFirst {
+            case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
+          }.getOrElse(e)
       }).asInstanceOf[Seq[NamedExpression]]
 
       val partialComputation = namedGroupingExpressions.map(_._2) ++
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b55a65266008..1c40a9209f6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -223,10 +223,9 @@ case class GeneratedAggregate(
     val resultExpressions = aggregateExpressions.map(_.transform {
       case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
       case e: Expression =>
-        namedGroups
-          .find { case (expr, _) => expr semanticEquals e }
-          .map(_._2)
-          .getOrElse(e)
+        namedGroups.collectFirst {
+          case (expr, attr) if expr semanticEquals e => attr
+        }.getOrElse(e)
     })
 
     val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 14ecd4e9a77d..6898d584414b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -697,7 +697,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
         row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
   }
 
-  ignore("cartesian product join") {
+  test("cartesian product join") {
     checkAnswer(
       testData3.join(testData3),
       Row(1, null, 1, null) ::
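
Why patch 1/2 rewrites semanticEquals: the previous version compared the product
elements of two expressions and only recursed when an element was itself an
Expression. Children wrapped in an Option or held inside a Seq (for example the
children of Coalesce) fell through to plain ==, which is sensitive to cosmetic
differences such as attribute-name capitalization. The following is a minimal,
self-contained sketch in Scala 2.x (matching Spark of this era); Expr, AttrRef,
Coalesce and Demo are simplified stand-ins for illustration, not Spark's real
classes:

// Toy model: AttrRef stands in for an attribute whose name may differ only
// "cosmetically" in capitalization. Coalesce keeps its children in a Seq,
// which is exactly the shape the old comparison missed.
abstract class Expr extends Product {
  def semanticEquals(other: Expr): Boolean = this.getClass == other.getClass && {
    // Recursive comparison over product elements, mirroring checkSemantic in
    // the patch: descend into Exprs, Options and Traversables instead of
    // falling back to ==.
    def check(es1: Seq[Any], es2: Seq[Any]): Boolean =
      es1.length == es2.length && es1.zip(es2).forall {
        case (e1: Expr, e2: Expr) => e1 semanticEquals e2
        case (Some(e1: Expr), Some(e2: Expr)) => e1 semanticEquals e2
        case (t1: Traversable[_], t2: Traversable[_]) => check(t1.toSeq, t2.toSeq)
        case (i1, i2) => i1 == i2
      }
    check(this.productIterator.toSeq, other.asInstanceOf[Product].productIterator.toSeq)
  }
}

case class AttrRef(name: String) extends Expr {
  // Semantic equality ignores capitalization, unlike case-class ==.
  override def semanticEquals(other: Expr): Boolean = other match {
    case AttrRef(n) => name.equalsIgnoreCase(n)
    case _ => false
  }
}

case class Coalesce(children: Seq[Expr]) extends Expr

object Demo extends App {
  val a = Coalesce(Seq(AttrRef("col")))
  val b = Coalesce(Seq(AttrRef("COL")))
  // Prints true: the Traversable case recurses into the Seq and compares the
  // AttrRefs semantically. Without that case (the old behavior), the Seqs are
  // compared with ==, the names differ in case, and the result is false.
  println(a semanticEquals b)
}

The same reasoning motivates the patterns.scala and GeneratedAggregate.scala
hunks: looking a grouping expression up in a Map keyed by Expression uses ==,
so patch 1 replaces those Maps with Seqs of pairs searched via semanticEquals,
and patch 2 tightens the find/map chains into collectFirst.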