[SPARK-4322][SQL] Enables struct fields as sub expressions of grouping fields #3248

Closed · wants to merge 3 commits
@@ -60,7 +60,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       ResolveFunctions ::
       GlobalAggregates ::
       UnresolvedHavingClauseAttributes ::
-      TrimAliases ::
+      TrimGroupingAliases ::
       typeCoercionRules ++
       extendedRules : _*),
     Batch("Check Analysis", Once,
@@ -93,17 +93,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
   /**
    * Removes no-op Alias expressions from the plan.
    */
-  object TrimAliases extends Rule[LogicalPlan] {
+  object TrimGroupingAliases extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case Aggregate(groups, aggs, child) =>
-        Aggregate(
-          groups.map {
-            _ transform {
-              case Alias(c, _) => c
-            }
-          },
-          aggs,
-          child)
+        Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
     }
   }
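Both the old and the renamed rule fire only on Aggregate nodes and strip no-op Alias wrappers from the grouping expressions; the new name simply makes that scope explicit, while the body collapses to one line. Below is a minimal, self-contained sketch of the trimming pattern; Expr, Alias, GetField and AttributeRef are illustrative stand-ins, not the actual Catalyst classes, and the transform helper is only a rough analogue of Catalyst's Expression.transform.

// Illustrative stand-ins for Catalyst expression nodes (not the real classes).
sealed trait Expr {
  // Rough analogue of Expression.transform: apply the rule to this node, then
  // recurse into the children of whatever the rule produced.
  def transform(rule: PartialFunction[Expr, Expr]): Expr =
    rule.applyOrElse(this, identity[Expr]) match {
      case Alias(child, name)     => Alias(child.transform(rule), name)
      case GetField(child, field) => GetField(child.transform(rule), field)
      case leaf                   => leaf
    }
}
case class AttributeRef(name: String) extends Expr
case class GetField(child: Expr, field: String) extends Expr
case class Alias(child: Expr, name: String) extends Expr

object TrimGroupingAliasesSketch extends App {
  // GROUP BY a.b, where resolution wrapped the struct access in an alias.
  val groups: Seq[Expr] = Seq(Alias(GetField(AttributeRef("a"), "b"), "b"))
  val trimmed = groups.map(_.transform { case Alias(c, _) => c })
  println(trimmed) // List(GetField(AttributeRef(a),b))
}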

@@ -122,10 +115,15 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
         case e => e.children.forall(isValidAggregateExpression)
       }

-      aggregateExprs.foreach { e =>
-        if (!isValidAggregateExpression(e)) {
-          throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
-        }
+      aggregateExprs.find { e =>
+        !isValidAggregateExpression(e.transform {
+          // Should trim aliases around `GetField`s. These aliases are introduced while
+          // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
+          // (Should we just turn `GetField` into a `NamedExpression`?)
> Inline review comment (Contributor), on the `NamedExpression` question in the added comment above:
>
> An earlier version of Catalyst actually did something similar to having GetField be a named expression. However, this complicated a lot of things. In particular, it made recursive schemas impossible, and it is a little tricky to have non-leaf nodes that have to have expression IDs.
>
> That said, I'm not opposed to this idea in general, but I think it might be a fair amount of work to realize. We should possibly evaluate a few options, as the current situation also seems less than ideal.
>
> I'd propose we merge this now to fix the bug and investigate after 1.2.

+          case Alias(g: GetField, _) => g
+        })
+      }.foreach { e =>
+        throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
       }

       aggregatePlan
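The reworked check normalizes each aggregate expression before testing it against the GROUP BY list: the grouping expressions have had their aliases removed, but the select-list copy of the same struct access still carries the Alias that resolution introduced around GetField, so a purely structural comparison would reject a query such as SELECT a.b + 1 ... GROUP BY a.b + 1. A rough, self-contained illustration of that mismatch, using plain case classes rather than Catalyst types:

// Illustrative stand-ins only; Field/Aliased/Add model the expression shapes involved.
sealed trait E
case class Attr(name: String) extends E
case class Field(child: E, name: String) extends E
case class Aliased(child: E, name: String) extends E
case class Add(left: E, right: E) extends E
case class Lit(value: Int) extends E

object GroupByCheckSketch extends App {
  // The normalization the patch adds: drop aliases wrapped around field accesses.
  def trimFieldAliases(e: E): E = e match {
    case Aliased(f: Field, _) => f
    case Add(l, r)            => Add(trimFieldAliases(l), trimFieldAliases(r))
    case other                => other
  }

  // GROUP BY a.b + 1, after TrimGroupingAliases has run:
  val grouping: Seq[E] = Seq(Add(Field(Attr("a"), "b"), Lit(1)))
  // SELECT a.b + 1, where resolving the struct access introduced an alias:
  val selected: E = Add(Aliased(Field(Attr("a"), "b"), "b"), Lit(1))

  println(grouping.contains(selected))                    // false: spurious "not in GROUP BY"
  println(grouping.contains(trimFieldAliases(selected)))  // true: matches once trimmed
}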
@@ -328,4 +326,3 @@ object EliminateAnalysisOperators extends Rule[LogicalPlan] {
     case Subquery(_, child) => child
   }
 }
-
@@ -151,8 +151,15 @@ object PartialAggregation {
         val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
           case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
             partialEvaluations(new TreeNodeRef(e)).finalEvaluation
-          case e: Expression if namedGroupingExpressions.contains(e) =>
-            namedGroupingExpressions(e).toAttribute
+
+          case e: Expression =>
+            // Should trim aliases around `GetField`s. These aliases are introduced while
+            // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
+            // (Should we just turn `GetField` into a `NamedExpression`?)
+            namedGroupingExpressions
+              .get(e.transform { case Alias(g: GetField, _) => g })
+              .map(_.toAttribute)
+              .getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]

         val partialComputation =
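The same normalization shows up here: instead of only rewriting an aggregate expression when it is literally contained in namedGroupingExpressions, the lookup key is first stripped of aliases around GetField, and the original expression is kept as the fallback when no grouping expression matches. A minimal sketch of that lookup pattern in plain Scala (the string keys and attribute names are stand-ins, not Catalyst expressions):

object GroupingLookupSketch extends App {
  // Stand-in for namedGroupingExpressions: normalized grouping expression -> output attribute.
  val namedGroupingExpressions: Map[String, String] = Map("GetField(a,b)" -> "b#1")

  // Stand-in for `_.transform { case Alias(g: GetField, _) => g }`.
  def trimAlias(e: String): String = e.stripPrefix("Alias:")

  // Old shape: only rewrite on an exact match, so the aliased form falls through untouched.
  def rewriteOld(e: String): String =
    if (namedGroupingExpressions.contains(e)) namedGroupingExpressions(e) else e

  // New shape: normalize the key, then fall back to the original expression.
  def rewriteNew(e: String): String =
    namedGroupingExpressions.get(trimAlias(e)).getOrElse(e)

  println(rewriteOld("Alias:GetField(a,b)")) // Alias:GetField(a,b) -- not rewritten
  println(rewriteNew("Alias:GetField(a,b)")) // b#1
}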
@@ -188,7 +195,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
       logDebug(s"Considering join on: $condition")
       // Find equi-join predicates that can be evaluated before the join, and thus can be used
       // as join keys.
-      val (joinPredicates, otherPredicates) = 
+      val (joinPredicates, otherPredicates) =
         condition.map(splitConjunctivePredicates).getOrElse(Nil).partition {
           case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
             (canEvaluate(l, right) && canEvaluate(r, left)) => true
@@ -203,7 +210,7 @@
       val rightKeys = joinKeys.map(_._2)

       if (joinKeys.nonEmpty) {
-        logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}")
+        logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
         Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
       } else {
         None
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (12 changes: 11 additions & 1 deletion)
@@ -544,7 +544,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil)
   }

-  test("INTERSECT") { 
+  test("INTERSECT") {
     checkAnswer(
       sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"),
       (1, "a") ::
@@ -942,4 +942,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"),
       (1 to 99).map(i => Seq(i)))
   }
+
+  test("SPARK-4322 Grouping field with struct field as sub expression") {
+    jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data")
+    checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1)
+    dropTempTable("data")
+
+    jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
+    checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2)
+    dropTempTable("data")
+  }
 }
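The new test mirrors the SPARK-4322 report: grouping on a struct field used as a sub expression (a.b[0].c, a.b + 1) used to trip the "Expression not in GROUP BY" check above. As a rough, 1.2-era spark-shell style reproduction (sc and sqlContext are assumed session state, not part of the patch):

import org.apache.spark.sql.SQLContext

// Assumes a running spark-shell, i.e. `sc` is the SparkContext.
val sqlContext = new SQLContext(sc)

// One JSON record with a nested struct field.
sqlContext.jsonRDD(sc.parallelize("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")

// Without the patch the analyzer rejects this with "Expression not in GROUP BY";
// with it, the query should return a single Row(2).
sqlContext.sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1").collect()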