@@ -18,58 +18,47 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
- * push down operations into [[CreateNamedStructLike]].
- */
-object SimplifyCreateStructOps extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformExpressionsUp {
-      // push down field extraction
+ * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions.
+ */
+object SimplifyExtractValueOps extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    // One place where this optimization is invalid is an aggregation where the select
+    // list expression is a function of a grouping expression:
+    //
+    // SELECT struct(a,b).a FROM tbl GROUP BY struct(a,b)
+    //
+    // cannot be simplified to SELECT a FROM tbl GROUP BY struct(a,b). So just skip this
+    // optimization for Aggregates (although this misses some cases where the optimization
+    // can be made).
+    case a: Aggregate => a
+    case p => p.transformExpressionsUp {
+      // Remove redundant field extraction.
       case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) =>
         createNamedStructLike.valExprs(ordinal)
-    }
-  }
-}
 
-/**
- * push down operations into [[CreateArray]].
- */
-object SimplifyCreateArrayOps extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformExpressionsUp {
-      // push down field selection (array of structs)
-      case GetArrayStructFields(CreateArray(elems), field, ordinal, numFields, containsNull) =>
-        // instead f selecting the field on the entire array,
-        // select it from each member of the array.
-        // pushing down the operation this way open other optimizations opportunities
-        // (i.e. struct(...,x,...).x)
+      // Remove redundant array indexing.
+      case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
+        // Instead of selecting the field on the entire array, select it from each member
+        // of the array. Pushing down the operation this way may open other optimizations
+        // opportunities (i.e. struct(...,x,...).x)
         CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))))
-      // push down item selection.
+
+      // Remove redundant map lookup.
       case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) =>
-        // instead of creating the array and then selecting one row,
-        // remove array creation altgether.
+        // Instead of creating the array and then selecting one row, remove array creation
+        // altogether.
         if (idx >= 0 && idx < elems.size) {
           // valid index
           elems(idx)
         } else {
           // out of bounds, mimic the runtime behavior and return null
           Literal(null, ga.dataType)
         }
-    }
-  }
-}
-
-/**
- * push down operations into [[CreateMap]].
- */
-object SimplifyCreateMapOps extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformExpressionsUp {
       case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems)
     }
   }
 }
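
For reference, a minimal standalone sketch of the expression-level rewrites the merged rule performs, assuming a Spark 2.3/2.4-era Catalyst classpath (where the CreateNamedStructLike trait and the single-argument CreateArray pattern exist); the struct/extract/item values below are illustrative, not part of the patch:

    import org.apache.spark.sql.catalyst.expressions._

    // struct("att1", 42L).att1 collapses to the value expression itself.
    val struct = CreateNamedStruct(Seq(Literal("att1"), Literal(42L)))
    val extract = GetStructField(struct, 0, Some("att1"))
    val simplified = extract match {
      case GetStructField(s: CreateNamedStructLike, ordinal, _) => s.valExprs(ordinal)
      case other => other
    }
    // simplified == Literal(42L)

    // array(1L, 2L)[5] folds to a null literal of the element type,
    // mimicking the out-of-bounds runtime behavior.
    val item = GetArrayItem(CreateArray(Seq(Literal(1L), Literal(2L))), Literal(5))
    val folded = item match {
      case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) =>
        if (idx >= 0 && idx < elems.size) elems(idx) else Literal(null, ga.dataType)
      case other => other
    }
    // folded == Literal(null, LongType)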

@@ -85,9 +85,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       EliminateSerialization,
       RemoveRedundantAliases,
       RemoveRedundantProject,
-      SimplifyCreateStructOps,
-      SimplifyCreateArrayOps,
-      SimplifyCreateMapOps,
+      SimplifyExtractValueOps,
       CombineConcats) ++
     extendedOperatorOptimizationRules

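SimplifyExtractValueOps is registered like any other Rule[LogicalPlan] in the operator-optimization batch, so it can also be applied on its own. A hedged sketch using the Catalyst test DSL, mirroring the suite below (the rel and plan names and the struct field name "a" are illustrative assumptions):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val rel = LocalRelation('id.long.notNull)
    val plan = rel
      .select(GetStructField(CreateNamedStruct(Seq(Literal("a"), 'id)), 0, None) as "x")
      .analyze
    // Applying the single merged rule is enough to collapse struct(a, id).a:
    val optimized = SimplifyExtractValueOps(plan)
    // optimized now selects 'id as "x" directly.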
@@ -44,14 +44,13 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
       BooleanSimplification,
       SimplifyConditionals,
       SimplifyBinaryComparison,
-      SimplifyCreateStructOps,
-      SimplifyCreateArrayOps,
-      SimplifyCreateMapOps) :: Nil
+      SimplifyExtractValueOps) :: Nil
   }
 
   val idAtt = ('id).long.notNull
+  val nullableIdAtt = ('nullable_id).long
 
-  lazy val relation = LocalRelation(idAtt )
+  lazy val relation = LocalRelation(idAtt, nullableIdAtt)
 
   test("explicit get from namedStruct") {
     val query = relation
@@ -321,7 +320,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
       .select(
         CaseWhen(Seq(
           (EqualTo(2L, 'id), ('id + 1L)),
-          // these two are possible matches, we can't tell untill runtime
+          // these two are possible matches, we can't tell until runtime
           (EqualTo(2L, ('id + 1L)), ('id + 2L)),
           (EqualTo(2L, 'id + 2L), Literal.create(null, LongType)),
           // this is a definite match (two constants),
@@ -331,4 +330,50 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
       .analyze
     comparePlans(Optimizer execute rel, expected)
   }
+
+  test("SPARK-23500: Simplify complex ops that aren't at the plan root") {
+    val structRel = relation
+      .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo")
+      .groupBy($"foo")("1").analyze
+    val structExpected = relation
+      .select('nullable_id as "foo")
+      .groupBy($"foo")("1").analyze
+    comparePlans(Optimizer execute structRel, structExpected)
+
+    // These tests must use nullable attributes from the base relation for the following reason:
+    // in the 'original' plans below, the Aggregate node produced by groupBy() has a
+    // nullable AttributeReference to a1, because both array indexing and map lookup are
+    // nullable expressions. After optimization, the same attribute is now non-nullable,
+    // but the AttributeReference is not updated to reflect this. In the 'expected' plans,
+    // the grouping expressions have the same nullability as the original attribute in the
+    // relation. If that attribute is non-nullable, the tests will fail as the plans will
+    // compare differently, so for these tests we must use a nullable attribute. See
+    // SPARK-23634.
+    val arrayRel = relation
+      .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1")
+      .groupBy($"a1")("1").analyze
+    val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1").analyze
+    comparePlans(Optimizer execute arrayRel, arrayExpected)
+
+    val mapRel = relation
+      .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1")
+      .groupBy($"m1")("1").analyze
+    val mapExpected = relation
+      .select('nullable_id as "m1")
+      .groupBy($"m1")("1").analyze
+    comparePlans(Optimizer execute mapRel, mapExpected)
[Review comment from dongjoon-hyun (Member), Mar 15, 2018]

@henryr: Could you add more test cases mentioned today, for example, like the following? We need a test case for array, too.

    val structRel = relation.groupBy(
      CreateNamedStruct(Seq("att1", 'nullable_id)))(
      GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze
    comparePlans(Optimizer execute structRel, structRel)

+  }
+
+  test("SPARK-23500: Ensure that aggregation expressions are not simplified") {
+    // Make sure that aggregation exprs are correctly ignored. Maps can't be used in
+    // grouping exprs so aren't tested here.
+    val structAggRel = relation.groupBy(
+      CreateNamedStruct(Seq("att1", 'nullable_id)))(
+      GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)).analyze
+    comparePlans(Optimizer execute structAggRel, structAggRel)
+
+    val arrayAggRel = relation.groupBy(
+      CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)).analyze
+    comparePlans(Optimizer execute arrayAggRel, arrayAggRel)
+  }
 }
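
The nullability caveat in the SPARK-23634 test comment above can be observed directly. A small sketch, again assuming 2.3/2.4-era Catalyst where GetArrayItem is always nullable (as the comment states); the elems/item names are illustrative:

    import org.apache.spark.sql.catalyst.expressions._

    val elems = Seq(Literal(1L), Literal(2L))        // non-nullable children
    val item = GetArrayItem(CreateArray(elems), Literal(0))
    assert(item.nullable)              // indexing may be out of bounds => nullable
    assert(elems.forall(!_.nullable))
    // After SimplifyExtractValueOps this becomes elems(0), which is non-nullable.
    // An AttributeReference captured before optimization keeps the stale nullable
    // flag, which is why the tests group by a nullable base attribute.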