Skip to content

Commit 477d6bd

Browse files
henryrgatorsmile
authored andcommitted
[SPARK-23500][SQL] Fix complex type simplification rules to apply to entire plan
## What changes were proposed in this pull request? Complex type simplification optimizer rules were not applied to the entire plan, just the expressions reachable from the root node. This patch fixes the rules to transform the entire plan. ## How was this patch tested? New unit test + ran sql / core tests. Author: Henry Robinson <henry@apache.org> Author: Henry Robinson <henry@cloudera.com> Closes #20687 from henryr/spark-25000.
1 parent 2c4b996 commit 477d6bd

File tree

3 files changed

+76
-44
lines changed

3 files changed

+76
-44
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,58 +18,47 @@
1818
package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.expressions._
21-
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
21+
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
2222
import org.apache.spark.sql.catalyst.rules.Rule
2323

2424
/**
25-
* push down operations into [[CreateNamedStructLike]].
26-
*/
27-
object SimplifyCreateStructOps extends Rule[LogicalPlan] {
28-
override def apply(plan: LogicalPlan): LogicalPlan = {
29-
plan.transformExpressionsUp {
30-
// push down field extraction
25+
* Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions.
26+
*/
27+
object SimplifyExtractValueOps extends Rule[LogicalPlan] {
28+
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
29+
// One place where this optimization is invalid is an aggregation where the select
30+
// list expression is a function of a grouping expression:
31+
//
32+
// SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
33+
//
34+
// cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
35+
// optimization for Aggregates (although this misses some cases where the optimization
36+
// can be made).
37+
case a: Aggregate => a
38+
case p => p.transformExpressionsUp {
39+
// Remove redundant field extraction.
3140
case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) =>
3241
createNamedStructLike.valExprs(ordinal)
33-
}
34-
}
35-
}
3642

37-
/**
38-
* push down operations into [[CreateArray]].
39-
*/
40-
object SimplifyCreateArrayOps extends Rule[LogicalPlan] {
41-
override def apply(plan: LogicalPlan): LogicalPlan = {
42-
plan.transformExpressionsUp {
43-
// push down field selection (array of structs)
44-
case GetArrayStructFields(CreateArray(elems), field, ordinal, numFields, containsNull) =>
45-
// instead f selecting the field on the entire array,
46-
// select it from each member of the array.
47-
// pushing down the operation this way open other optimizations opportunities
48-
// (i.e. struct(...,x,...).x)
43+
// Remove redundant array indexing.
44+
case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
45+
// Instead of selecting the field on the entire array, select it from each member
46+
// of the array. Pushing down the operation this way may open other optimizations
47+
// opportunities (i.e. struct(...,x,...).x)
4948
CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))))
50-
// push down item selection.
49+
50+
// Remove redundant map lookup.
5151
case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) =>
52-
// instead of creating the array and then selecting one row,
53-
// remove array creation altgether.
52+
// Instead of creating the array and then selecting one row, remove array creation
53+
// altogether.
5454
if (idx >= 0 && idx < elems.size) {
5555
// valid index
5656
elems(idx)
5757
} else {
5858
// out of bounds, mimic the runtime behavior and return null
5959
Literal(null, ga.dataType)
6060
}
61-
}
62-
}
63-
}
64-
65-
/**
66-
* push down operations into [[CreateMap]].
67-
*/
68-
object SimplifyCreateMapOps extends Rule[LogicalPlan] {
69-
override def apply(plan: LogicalPlan): LogicalPlan = {
70-
plan.transformExpressionsUp {
7161
case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems)
7262
}
7363
}
7464
}
75-

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
8585
EliminateSerialization,
8686
RemoveRedundantAliases,
8787
RemoveRedundantProject,
88-
SimplifyCreateStructOps,
89-
SimplifyCreateArrayOps,
90-
SimplifyCreateMapOps,
88+
SimplifyExtractValueOps,
9189
CombineConcats) ++
9290
extendedOperatorOptimizationRules
9391

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,13 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
4444
BooleanSimplification,
4545
SimplifyConditionals,
4646
SimplifyBinaryComparison,
47-
SimplifyCreateStructOps,
48-
SimplifyCreateArrayOps,
49-
SimplifyCreateMapOps) :: Nil
47+
SimplifyExtractValueOps) :: Nil
5048
}
5149

5250
val idAtt = ('id).long.notNull
51+
val nullableIdAtt = ('nullable_id).long
5352

54-
lazy val relation = LocalRelation(idAtt )
53+
lazy val relation = LocalRelation(idAtt, nullableIdAtt)
5554

5655
test("explicit get from namedStruct") {
5756
val query = relation
@@ -321,7 +320,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
321320
.select(
322321
CaseWhen(Seq(
323322
(EqualTo(2L, 'id), ('id + 1L)),
324-
// these two are possible matches, we can't tell untill runtime
323+
// these two are possible matches, we can't tell until runtime
325324
(EqualTo(2L, ('id + 1L)), ('id + 2L)),
326325
(EqualTo(2L, 'id + 2L), Literal.create(null, LongType)),
327326
// this is a definite match (two constants),
@@ -331,4 +330,50 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
331330
.analyze
332331
comparePlans(Optimizer execute rel, expected)
333332
}
333+
334+
test("SPARK-23500: Simplify complex ops that aren't at the plan root") {
335+
val structRel = relation
336+
.select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo")
337+
.groupBy($"foo")("1").analyze
338+
val structExpected = relation
339+
.select('nullable_id as "foo")
340+
.groupBy($"foo")("1").analyze
341+
comparePlans(Optimizer execute structRel, structExpected)
342+
343+
// These tests must use nullable attributes from the base relation for the following reason:
344+
// in the 'original' plans below, the Aggregate node produced by groupBy() has a
345+
// nullable AttributeReference to a1, because both array indexing and map lookup are
346+
// nullable expressions. After optimization, the same attribute is now non-nullable,
347+
// but the AttributeReference is not updated to reflect this. In the 'expected' plans,
348+
// the grouping expressions have the same nullability as the original attribute in the
349+
// relation. If that attribute is non-nullable, the tests will fail as the plans will
350+
// compare differently, so for these tests we must use a nullable attribute. See
351+
// SPARK-23634.
352+
val arrayRel = relation
353+
.select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
354+
.groupBy($"a1")("1").analyze
355+
val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze
356+
comparePlans(Optimizer execute arrayRel, arrayExpected)
357+
358+
val mapRel = relation
359+
.select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1")
360+
.groupBy($"m1")("1").analyze
361+
val mapExpected = relation
362+
.select('nullable_id as "m1")
363+
.groupBy($"m1")("1").analyze
364+
comparePlans(Optimizer execute mapRel, mapExpected)
365+
}
366+
367+
test("SPARK-23500: Ensure that aggregation expressions are not simplified") {
368+
// Make sure that aggregation exprs are correctly ignored. Maps can't be used in
369+
// grouping exprs so aren't tested here.
370+
val structAggRel = relation.groupBy(
371+
CreateNamedStruct(Seq("att1", 'nullable_id)))(
372+
GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze
373+
comparePlans(Optimizer execute structAggRel, structAggRel)
374+
375+
val arrayAggRel = relation.groupBy(
376+
CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze
377+
comparePlans(Optimizer execute arrayAggRel, arrayAggRel)
378+
}
334379
}

0 commit comments

Comments
 (0)