diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7b9b21f41641..28818d805f0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -79,6 +80,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       ReorderJoin,
       OuterJoinElimination,
       PushPredicateThroughJoin,
+      PushPredicateThroughObjectConsumer,
       PushDownPredicate,
       LimitPushDown,
       ColumnPruning,
@@ -1159,6 +1161,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     // should not push predicates through sample, or will generate different results.
     case filter @ Filter(_, _: Sample) => filter
 
+    // predicates on ObjectConsumer are handled by PushPredicateThroughObjectConsumer instead
+    case filter @ Filter(_, _: ObjectConsumer) => filter
+
     case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
       pushDownPredicate(filter, u.child) { predicate =>
         u.withNewChildren(Seq(Filter(predicate, u.child)))
@@ -1411,6 +1416,41 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   }
 }
 
+/**
+ * Pushes down [[Filter]] operators through [[ObjectConsumer]].
+ */
+object PushPredicateThroughObjectConsumer extends Rule[LogicalPlan] with PredicateHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+
+    // Pushes the filter down if the child is [[SerializeFromObject]] and the condition is a
+    // function applied to the deserialized object.
+    case filter @ Filter(condition, sfo @ SerializeFromObject(serializer, grandchild)) =>
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
+        case cond @ Invoke(_, _, _, deserializer :: Nil, _) =>
+          cond.deterministic && deserializer.dataType == sfo.inputObjectType
+        case _ => false
+      }
+      if (pushDown.nonEmpty) {
+        val newChild = SerializeFromObject(
+          serializer,
+          Filter(
+            pushDown.map {
+              case Invoke(function, name, returnType, _, propagateNull) =>
+                Invoke(function, name, returnType, grandchild.output.head :: Nil, propagateNull)
+            }.reduceLeft(And),
+            grandchild))
+        if (stayUp.nonEmpty) {
+          Filter(stayUp.reduceLeft(And), newChild)
+        } else {
+          newChild
+        }
+      } else {
+        filter
+      }
+  }
+}
+
 /**
  * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
 */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index fcc14a803bea..6f1472cbdcae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -17,15 +17,17 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedDeserializer}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
 
 class FilterPushdownSuite extends PlanTest {
 
@@ -39,6 +41,7 @@ class FilterPushdownSuite extends PlanTest {
         PushDownPredicate,
         BooleanSimplification,
         PushPredicateThroughJoin,
+        PushPredicateThroughObjectConsumer,
         CollapseProject) :: Nil
   }
 
@@ -980,4 +983,103 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer)
   }
+
+  test("ObjectConsumer: can't push down expression filter through ObjectConsumer") {
+    implicit val tEncoder =
+      Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt, Encoders.scalaInt)
+    implicit val uEncoder = Encoders.scalaBoolean
+    val mapFunc: (Int, Int, Int) => Boolean = { (a, _, _) => a == 0 }
+
+    // testRelation.map(mapFunc).where('value === true)
+    val originalQuery = {
+      val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
+      val mapped = MapElements(
+        mapFunc,
+        CatalystSerde.generateObjAttr[Boolean],
+        deserialized)
+      val serialized = CatalystSerde.serialize[Boolean](mapped)
+      Filter('value === true, serialized)
+    }
+
+    val correctAnswer = originalQuery
+
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+  }
+
+  test("ObjectConsumer: can push down function filter through SerializeFromObject") {
+    implicit val tEncoder =
+      Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt, Encoders.scalaInt)
+    implicit val uEncoder = Encoders.scalaBoolean
+    val mapFunc: (Int, Int, Int) => Boolean = { (a, _, _) => a == 0 }
+    val filterFunc: Boolean => Boolean = { _ == true }
+
+    // testRelation.map(mapFunc).filter(filterFunc)
+    val originalQuery = {
+      val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
+      val mapped = MapElements(
+        mapFunc,
+        CatalystSerde.generateObjAttr[Boolean],
+        deserialized)
+      val serialized = CatalystSerde.serialize[Boolean](mapped)
+
+      val deserializer = UnresolvedDeserializer(encoderFor[Boolean].deserializer)
+      val condition = callFunction(filterFunc, BooleanType, deserializer)
+      Filter(condition, serialized)
+    }
+
+    val correctAnswer = {
+      val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
+      val mapped = MapElements(
+        mapFunc,
+        CatalystSerde.generateObjAttr[Boolean],
+        deserialized)
+      val filtered = Filter(callFunction(filterFunc, BooleanType, mapped.outputObjAttr), mapped)
+      CatalystSerde.serialize[Boolean](filtered)
+    }
+
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+  }
+
+  test("ObjectConsumer: can push down complex filter through SerializeFromObject") {
+    implicit val tEncoder =
+      Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt, Encoders.scalaInt)
+    implicit val uEncoder = Encoders.scalaBoolean
+    val mapFunc: (Int, Int, Int) => Boolean = { (a, _, _) => a == 0 }
+    val filterFunc1: Boolean => Boolean = { _ == true }
+    val filterFunc2: Boolean => Boolean = { _ == false }
+
+    // testRelation.map(mapFunc).filter(filterFunc1).where('value === true).filter(filterFunc2)
+    val originalQuery = {
+      val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
+      val mapped = MapElements(
+        mapFunc,
+        CatalystSerde.generateObjAttr[Boolean],
+        deserialized)
+      val serialized = CatalystSerde.serialize[Boolean](mapped)
+
+      val deserializer = UnresolvedDeserializer(encoderFor[Boolean].deserializer)
+      val condition1 = callFunction(filterFunc1, BooleanType, deserializer)
+      val filtered1 = Filter(condition1, serialized)
+
+      val where = Filter('value === true, filtered1)
+
+      val condition2 = callFunction(filterFunc2, BooleanType, deserializer)
+      Filter(condition2, where)
+    }
+
+    val correctAnswer = {
+      val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
+      val mapped = MapElements(
+        mapFunc,
+        CatalystSerde.generateObjAttr[Boolean],
+        deserialized)
+      val cond1 = callFunction(filterFunc1, BooleanType, mapped.outputObjAttr)
+      val cond2 = callFunction(filterFunc2, BooleanType, mapped.outputObjAttr)
+      val filtered = Filter(cond1 && cond2, mapped)
+      val serialized = CatalystSerde.serialize[Boolean](filtered)
+      Filter('value === true, serialized)
+    }
+
+    comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f02a3141a050..9cfbc76b789e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -227,6 +227,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "b")
   }
 
+  test("map and filter") {
+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+    checkDataset(
+      ds.map(v => (v._1, v._2 + 1)).filter(_._1 == "b"),
+      ("b", 3))
+  }
+
   test("SPARK-15632: typed filter should preserve the underlying logical schema") {
     val ds = spark.range(10)
    val ds2 = ds.filter(_ > 3)
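
Illustration (not part of the patch): a minimal, standalone sketch of the Dataset-level pattern the new rule targets, mirroring the "map and filter" test added to DatasetSuite above. The object name TypedFilterPushdownExample and the local SparkSession setup are assumptions made for this example; with PushPredicateThroughObjectConsumer in the optimizer, the typed filter is expected to end up beneath the SerializeFromObject introduced by map, which explain(true) should show in the optimized plan.

import org.apache.spark.sql.SparkSession

// Hypothetical standalone driver; mirrors the "map and filter" test in DatasetSuite.
object TypedFilterPushdownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("typed-filter-pushdown-example")
      .getOrCreate()
    import spark.implicits._

    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()

    // map is planned as DeserializeToObject -> MapElements -> SerializeFromObject.
    // The typed filter below becomes an Invoke on the deserialized object, which
    // PushPredicateThroughObjectConsumer can move below SerializeFromObject.
    val result = ds.map(v => (v._1, v._2 + 1)).filter(_._1 == "b")

    result.explain(true)  // optimized plan should show the Filter under SerializeFromObject
    result.show()         // expected single row: (b, 3)

    spark.stop()
  }
}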