Skip to content

Commit c19c785

Browse files
cloud-fan authored and liancheng committed
[SQL] [MINOR] correct semanticEquals logic
It's a follow-up of #6173. For expressions like `Coalesce` that have a `Seq[Expression]`, when we do a semantic-equality check for it, we need to do a semantic-equality check for all of its children. Also we can just use `Seq[(Expression, NamedExpression)]` instead of `Map[Expression, NamedExpression]` as we only search it with `find`. chenghao-intel, I agree that we probably never know `semanticEquals` in a general way, but I think we have done that in `TreeNode`, so we can use similar logic. Then we can handle something like `Coalesce(children: Seq[Expression])` correctly. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #6261 from cloud-fan/tmp and squashes the following commits: 4daef88 [Wenchen Fan] address comments dd8fbd9 [Wenchen Fan] correct semanticEquals
1 parent e428b3a commit c19c785

File tree

4 files changed

+25
-22
lines changed

4 files changed

+25
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,17 @@ abstract class Expression extends TreeNode[Expression] {
136136
* cosmetically (i.e. capitalization of names in attributes may be different).
137137
*/
138138
def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
139+
def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
140+
elements1.length == elements2.length && elements1.zip(elements2).forall {
141+
case (e1: Expression, e2: Expression) => e1 semanticEquals e2
142+
case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
143+
case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
144+
case (i1, i2) => i1 == i2
145+
}
146+
}
139147
val elements1 = this.productIterator.toSeq
140148
val elements2 = other.asInstanceOf[Product].productIterator.toSeq
141-
elements1.length == elements2.length && elements1.zip(elements2).forall {
142-
case (e1: Expression, e2: Expression) => e1 semanticEquals e2
143-
case (i1, i2) => i1 == i2
144-
}
149+
checkSemantic(elements1, elements2)
145150
}
146151

147152
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,11 @@ object PartialAggregation {
143143
// We need to pass all grouping expressions though so the grouping can happen a second
144144
// time. However some of them might be unnamed so we alias them allowing them to be
145145
// referenced in the second aggregation.
146-
val namedGroupingExpressions: Map[Expression, NamedExpression] =
146+
val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
147147
groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
148148
case n: NamedExpression => (n, n)
149149
case other => (other, Alias(other, "PartialGroup")())
150-
}.toMap
150+
}
151151

152152
// Replace aggregations with a new expression that computes the result from the already
153153
// computed partial evaluations and grouping values.
@@ -160,17 +160,15 @@ object PartialAggregation {
160160
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
161161
// (Should we just turn `GetField` into a `NamedExpression`?)
162162
val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
163-
namedGroupingExpressions
164-
.find { case (k, v) => k semanticEquals trimmed }
165-
.map(_._2.toAttribute)
166-
.getOrElse(e)
163+
namedGroupingExpressions.collectFirst {
164+
case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
165+
}.getOrElse(e)
167166
}).asInstanceOf[Seq[NamedExpression]]
168167

169-
val partialComputation =
170-
(namedGroupingExpressions.values ++
171-
partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
168+
val partialComputation = namedGroupingExpressions.map(_._2) ++
169+
partialEvaluations.values.flatMap(_.partialEvaluations)
172170

173-
val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
171+
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
174172

175173
Some(
176174
(namedGroupingAttributes,

sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,18 +214,18 @@ case class GeneratedAggregate(
214214
}.toMap
215215

216216
val namedGroups = groupingExpressions.zipWithIndex.map {
217-
case (ne: NamedExpression, _) => (ne, ne)
218-
case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
217+
case (ne: NamedExpression, _) => (ne, ne.toAttribute)
218+
case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
219219
}
220220

221-
val groupMap: Map[Expression, Attribute] =
222-
namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
223-
224221
// The set of expressions that produce the final output given the aggregation buffer and the
225222
// grouping expressions.
226223
val resultExpressions = aggregateExpressions.map(_.transform {
227224
case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
228-
case e: Expression if groupMap.contains(e) => groupMap(e)
225+
case e: Expression =>
226+
namedGroups.collectFirst {
227+
case (expr, attr) if expr semanticEquals e => attr
228+
}.getOrElse(e)
229229
})
230230

231231
val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
@@ -265,7 +265,7 @@ case class GeneratedAggregate(
265265
val resultProjectionBuilder =
266266
newMutableProjection(
267267
resultExpressions,
268-
(namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
268+
namedGroups.map(_._2) ++ computationSchema)
269269
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
270270

271271
val joinedRow = new JoinedRow3

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
697697
row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
698698
}
699699

700-
ignore("cartesian product join") {
700+
test("cartesian product join") {
701701
checkAnswer(
702702
testData3.join(testData3),
703703
Row(1, null, 1, null) ::

0 commit comments

Comments
 (0)