sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -79,6 +80,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
ReorderJoin,
OuterJoinElimination,
PushPredicateThroughJoin,
PushPredicateThroughObjectConsumer,
PushDownPredicate,
LimitPushDown,
ColumnPruning,
@@ -1159,6 +1161,9 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// should not push predicates through sample, or will generate different results.
case filter @ Filter(_, _: Sample) => filter

// should not push predicates through ObjectConsumer; such filters are handled by other rules
case filter @ Filter(_, _: ObjectConsumer) => filter

case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
pushDownPredicate(filter, u.child) { predicate =>
u.withNewChildren(Seq(Filter(predicate, u.child)))
@@ -1411,6 +1416,41 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
* Pushes down [[Filter]] operators through [[ObjectConsumer]] nodes when the filter condition
* can be evaluated directly on the consumed object. Only the case where the child is
* [[SerializeFromObject]] and the condition invokes a function on the deserialized object is
* currently handled.
*/
object PushPredicateThroughObjectConsumer extends Rule[LogicalPlan] with PredicateHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {

// Push the filter down when the child is [[SerializeFromObject]] and the condition is a
// deterministic function invocation on the object being serialized.
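// Illustrative example (mirroring the tests in FilterPushdownSuite): a plan such as
//   Filter(Invoke(func, "apply", BooleanType, deserializer :: Nil),
//     SerializeFromObject(serializer, child))
// where the deserializer rebuilds the object consumed by the serializer, is rewritten to
//   SerializeFromObject(serializer,
//     Filter(Invoke(func, "apply", BooleanType, child.output.head :: Nil), child))
// so the function is invoked directly on the object produced by `child`.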
case filter @ Filter(condition, sfo @ SerializeFromObject(serializer, grandchild)) =>
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
case cond @ Invoke(_, _, _, deserializer :: Nil, _) =>
cond.deterministic && deserializer.dataType == sfo.inputObjectType
case _ => false
}
if (pushDown.nonEmpty) {
val newChild = SerializeFromObject(
serializer,
Filter(
pushDown.map {
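// Replace the deserializer argument with the child's output object attribute so that the
// function is invoked directly on the object produced by `grandchild`.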
case Invoke(function, name, returnType, _, propagateNull) =>
Invoke(function, name, returnType, grandchild.output.head :: Nil, propagateNull)
}.reduceLeft(And),
grandchild))
if (stayUp.nonEmpty) {
Filter(stayUp.reduceLeft(And), newChild)
} else {
newChild
}
} else {
filter
}
}
}

/**
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
*/
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -17,15 +17,17 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedDeserializer}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{BooleanType, IntegerType}

class FilterPushdownSuite extends PlanTest {

@@ -39,6 +41,7 @@ class FilterPushdownSuite extends PlanTest {
PushDownPredicate,
BooleanSimplification,
PushPredicateThroughJoin,
PushPredicateThroughObjectConsumer,
CollapseProject) :: Nil
}

@@ -980,4 +983,103 @@ class FilterPushdownSuite extends PlanTest {

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer)
}

test("ObjectConsumer: can't push down expression filter through ObjectConsumer") {
implicit val tEncoder =
Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt, Encoders.scalaInt)
implicit val uEncoder = Encoders.scalaBoolean
val mapFunc: (Int, Int, Int) => Boolean = { (a, _, _) => a == 0 }

// testRelation.map(mapFunc).where('value === true)
val originalQuery = {
val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
val mapped = MapElements(
mapFunc,
CatalystSerde.generateObjAttr[Boolean],
deserialized)
val serialized = CatalystSerde.serialize[Boolean](mapped)
Filter('value === true, serialized)
}

val correctAnswer = originalQuery

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}

test("ObjectConsumer: can push down function filter through SerializeFromObject") {
implicit val tEncoder =
Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt, Encoders.scalaInt)
implicit val uEncoder = Encoders.scalaBoolean
val mapFunc: (Int, Int, Int) => Boolean = { (a, _, _) => a == 0 }
val filterFunc: Boolean => Boolean = { _ == true }

// testRelation.map(mapFunc).filter(filterFunc)
val originalQuery = {
val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
val mapped = MapElements(
mapFunc,
CatalystSerde.generateObjAttr[Boolean],
deserialized)
val serialized = CatalystSerde.serialize[Boolean](mapped)

val deserializer = UnresolvedDeserializer(encoderFor[Boolean].deserializer)
val condition = callFunction(filterFunc, BooleanType, deserializer)
Filter(condition, serialized)
}

val correctAnswer = {
val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
val mapped = MapElements(
mapFunc,
CatalystSerde.generateObjAttr[Boolean],
deserialized)
val filtered = Filter(callFunction(filterFunc, BooleanType, mapped.outputObjAttr), mapped)
CatalystSerde.serialize[Boolean](filtered)
}

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}

test("ObjectConsumer: can push down complex filter through SerializeFromObject") {
implicit val tEncoder =
Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt, Encoders.scalaInt)
implicit val uEncoder = Encoders.scalaBoolean
val mapFunc: (Int, Int, Int) => Boolean = { (a, _, _) => a == 0 }
val filterFunc1: Boolean => Boolean = { _ == true }
val filterFunc2: Boolean => Boolean = { _ == false }

// testRelation.map(mapFunc).filter(filterFunc1).where('value === true).filter(filterFunc2)
val originalQuery = {
val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
val mapped = MapElements(
mapFunc,
CatalystSerde.generateObjAttr[Boolean],
deserialized)
val serialized = CatalystSerde.serialize[Boolean](mapped)

val deserializer = UnresolvedDeserializer(encoderFor[Boolean].deserializer)
val condition1 = callFunction(filterFunc1, BooleanType, deserializer)
val filtered1 = Filter(condition1, serialized)

val where = Filter('value === true, filtered1)

val condition2 = callFunction(filterFunc2, BooleanType, deserializer)
Filter(condition2, where)
}

val correctAnswer = {
val deserialized = CatalystSerde.deserialize[(Int, Int, Int)](testRelation)
val mapped = MapElements(
mapFunc,
CatalystSerde.generateObjAttr[Boolean],
deserialized)
val cond1 = callFunction(filterFunc1, BooleanType, mapped.outputObjAttr)
val cond2 = callFunction(filterFunc2, BooleanType, mapped.outputObjAttr)
val filtered = Filter(cond1 && cond2, mapped)
val serialized = CatalystSerde.serialize[Boolean](filtered)
Filter('value === true, serialized)
}

comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
}
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -227,6 +227,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
"b")
}

test("map and filter") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
checkDataset(
ds.map(v => (v._1, v._2 + 1)).filter(_._1 == "b"),
("b", 3))
}

test("SPARK-15632: typed filter should preserve the underlying logical schema") {
val ds = spark.range(10)
val ds2 = ds.filter(_ > 3)