Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,17 @@ abstract class Expression extends TreeNode[Expression] {
* cosmetically (i.e. capitalization of names in attributes may be different).
*/
def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
elements1.length == elements2.length && elements1.zip(elements2).forall {
case (e1: Expression, e2: Expression) => e1 semanticEquals e2
case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
case (i1, i2) => i1 == i2
}
}
val elements1 = this.productIterator.toSeq
val elements2 = other.asInstanceOf[Product].productIterator.toSeq
elements1.length == elements2.length && elements1.zip(elements2).forall {
case (e1: Expression, e2: Expression) => e1 semanticEquals e2
case (i1, i2) => i1 == i2
}
checkSemantic(elements1, elements2)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ object PartialAggregation {
// We need to pass all grouping expressions though so the grouping can happen a second
// time. However some of them might be unnamed so we alias them allowing them to be
// referenced in the second aggregation.
val namedGroupingExpressions: Map[Expression, NamedExpression] =
val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for expressions like Coalesce whose children order is sensitive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not about the order, it's just we don't need a map here as we search it with find, not get.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. Makes sense.

groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
case n: NamedExpression => (n, n)
case other => (other, Alias(other, "PartialGroup")())
}.toMap
}

// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
Expand All @@ -160,17 +160,15 @@ object PartialAggregation {
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
// (Should we just turn `GetField` into a `NamedExpression`?)
val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions
.find { case (k, v) => k semanticEquals trimmed }
.map(_._2.toAttribute)
.getOrElse(e)
namedGroupingExpressions.collectFirst {
case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]

val partialComputation =
(namedGroupingExpressions.values ++
partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
val partialComputation = namedGroupingExpressions.map(_._2) ++
partialEvaluations.values.flatMap(_.partialEvaluations)

val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)

Some(
(namedGroupingAttributes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,18 @@ case class GeneratedAggregate(
}.toMap

val namedGroups = groupingExpressions.zipWithIndex.map {
case (ne: NamedExpression, _) => (ne, ne)
case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
case (ne: NamedExpression, _) => (ne, ne.toAttribute)
case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
}

val groupMap: Map[Expression, Attribute] =
namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap

// The set of expressions that produce the final output given the aggregation buffer and the
// grouping expressions.
val resultExpressions = aggregateExpressions.map(_.transform {
case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
case e: Expression if groupMap.contains(e) => groupMap(e)
case e: Expression =>
namedGroups.collectFirst {
case (expr, attr) if expr semanticEquals e => attr
}.getOrElse(e)
})

val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
Expand Down Expand Up @@ -265,7 +265,7 @@ case class GeneratedAggregate(
val resultProjectionBuilder =
newMutableProjection(
resultExpressions,
(namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
namedGroups.map(_._2) ++ computationSchema)
log.info(s"Result Projection: ${resultExpressions.mkString(",")}")

val joinedRow = new JoinedRow3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}

ignore("cartesian product join") {
test("cartesian product join") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Expand Down