@@ -293,11 +293,7 @@ package object dsl {

def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)

def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
val deserialized = logicalPlan.deserialize[T]
val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
Filter(condition, deserialized).serialize[T]
}
def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan)

def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)

@@ -45,6 +45,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
var maxOrdinal = -1
result foreach {
case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
case _ =>
Contributor:
So this is actually a bug fix? Before, we could only use a single BoundReference as the result, right?

Contributor Author:
yup
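
A minimal sketch of the failure mode the new `case _ =>` avoids, using simplified stand-in expression types (not the real Catalyst classes): a pattern-matching function literal passed to foreach throws a MatchError on any node its single case does not cover, so previously the result expression could effectively contain only BoundReferences.

sealed trait Expr
case class BoundRef(ordinal: Int) extends Expr
case class Add(left: Expr, right: Expr) extends Expr

// With only the BoundRef case, the Add node below would throw a MatchError.
val nodes: Seq[Expr] = Seq(BoundRef(0), Add(BoundRef(1), BoundRef(2)))

var maxOrdinal = -1
nodes.foreach {
  case b: BoundRef if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
  case _ => // ignore non-BoundRef nodes instead of failing
}
// Only the top-level BoundRef(0) is visited; Seq.foreach does not recurse.
assert(maxOrdinal == 0)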

}
if (maxOrdinal > children.length) {
return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
@@ -110,8 +111,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
EmbedSerializerInFilter,
RemoveAliasOnlyProject) ::
CombineTypedFilters) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("OptimizeCodegen", Once,
@@ -206,15 +206,33 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
// We will remove it later in RemoveAliasOnlyProject rule.
val objAttr =
Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
Project(objAttr :: Nil, s.child)

case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
if a.deserializer.dataType == s.inputObjAttr.dataType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)

// If there is a `SerializeFromObject` under the typed filter and its input object type is the
// same as the type of the typed filter's deserializer, we can convert the typed filter to a
// normal filter whose condition needs no deserialization, and push it down through
// `SerializeFromObject`.
// e.g. `ds.map(...).filter(...)` can be optimized by this rule to save the extra deserialization,
// but `ds.map(...).as[AnotherType].filter(...)` cannot be optimized.
case f @ TypedFilter(_, _, s: SerializeFromObject)
if f.deserializer.dataType == s.inputObjAttr.dataType =>
s.copy(child = f.withObjectProducerChild(s.child))

// If there is a `DeserializeToObject` above the typed filter and its output object type is the
// same as the type of the typed filter's deserializer, we can convert the typed filter to a
// normal filter whose condition needs no deserialization, and pull it up through
// `DeserializeToObject`.
// e.g. `ds.filter(...).map(...)` can be optimized by this rule to save the extra deserialization,
// but `ds.filter(...).as[AnotherType].map(...)` cannot be optimized.
case d @ DeserializeToObject(_, _, f: TypedFilter)
if d.outputObjAttr.dataType == f.deserializer.dataType =>
f.withObjectProducerChild(d.copy(child = f.child))
}
}
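
A user-level sketch (assuming a SparkSession named `spark`; `Person` and `Wrapper` are hypothetical case classes) of the query shapes the two new `TypedFilter` cases above target: a typed filter adjacent to a typed map over the same object type can skip the serialize/deserialize round trip between them, while re-typing with `as[...]` blocks the rewrite.

case class Person(name: String, age: Int)
case class Wrapper(name: String, age: Int)

import spark.implicits._
val ds = Seq(Person("a", 20), Person("b", 30)).toDS()

// `SerializeFromObject` under the typed filter: both sides work on Person objects.
val pushedDown = ds.map(p => p.copy(age = p.age + 1)).filter(_.age > 21)

// `DeserializeToObject` above the typed filter: both sides work on Person objects.
val pulledUp = ds.filter(_.age > 21).map(_.name)

// Not rewritten by this rule: `as[Wrapper]` changes the object type in between.
val notRewritten = ds.map(p => p.copy(age = p.age + 1)).as[Wrapper].filter(_.age > 21)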

@@ -1645,54 +1663,30 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
}

/**
* Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
* [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
* the deserializer in filter condition to save the extra serialization at last.
* Combines two adjacent [[TypedFilter]]s that operate on the same object type in their conditions
* into one, merging the filter functions into a single conjunctive function.
*/
object EmbedSerializerInFilter extends Rule[LogicalPlan] {
object CombineTypedFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject))
// SPARK-15632: Conceptually, filter operator should never introduce schema change. This
// optimization rule also relies on this assumption. However, Dataset typed filter operator
// does introduce schema changes in some cases. Thus, we only enable this optimization when
//
// 1. either input and output schemata are exactly the same, or
// 2. both input and output schemata are single-field schema and share the same type.
//
// The 2nd case is included because encoders for primitive types always have only a single
// field with hard-coded field name "value".
// TODO Cleans this up after fixing SPARK-15632.
if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) =>

val numObjects = condition.collect {
case a: Attribute if a == d.output.head => a
}.length

if (numObjects > 1) {
// If the filter condition references the object more than one times, we should not embed
// deserializer in it as the deserialization will happen many times and slow down the
// execution.
// TODO: we can still embed it if we can make sure subexpression elimination works here.
s
} else {
val newCondition = condition transform {
case a: Attribute if a == d.output.head => d.deserializer
}
val filter = Filter(newCondition, d.child)

// Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`.
// We will remove it later in RemoveAliasOnlyProject rule.
val objAttrs = filter.output.zip(s.output).map { case (fout, sout) =>
Alias(fout, fout.name)(exprId = sout.exprId)
}
Project(objAttrs, filter)
}
}

def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = {
(lhs, rhs) match {
case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
case _ => false
case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
if t1.deserializer.dataType == t2.deserializer.dataType =>
TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
}

private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
(func1, func2) match {
case (f1: FilterFunction[_], f2: FilterFunction[_]) =>
input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
f2.asInstanceOf[FilterFunction[Any]].call(input)
case (f1: FilterFunction[_], f2) =>
input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
f2.asInstanceOf[Any => Boolean](input)
case (f1, f2: FilterFunction[_]) =>
input => f1.asInstanceOf[Any => Boolean].apply(input) &&
f2.asInstanceOf[FilterFunction[Any]].call(input)
case (f1, f2) =>
input => f1.asInstanceOf[Any => Boolean].apply(input) &&
f2.asInstanceOf[Any => Boolean].apply(input)
}
}
}
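
A small self-contained sketch (hypothetical predicates, plain Scala) of the conjunction semantics `combineFilterFunction` implements: the combined predicate evaluates the inner (earlier) filter first and short-circuits, so it behaves exactly like chaining the two filters.

// Inner filter (applied first) and outer filter (applied second).
val f1: ((Int, Int)) => Boolean = _._1 > 0
val f2: ((Int, Int)) => Boolean = _._2 > 0

// What CombineTypedFilters produces, conceptually: one conjunctive predicate.
val combined: ((Int, Int)) => Boolean = x => f1(x) && f2(x)

assert(combined((1, 2)))    // passes both predicates
assert(!combined((-1, 2)))  // fails the first; the second is never evaluated
assert(!combined((1, -2)))  // passes the first, fails the second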
@@ -17,11 +17,15 @@

package org.apache.spark.sql.catalyst.plans.logical

import scala.language.existentials

import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._

object CatalystSerde {
@@ -45,13 +49,11 @@
*/
trait ObjectProducer extends LogicalPlan {
// The attribute that reference to the single object field this operator outputs.
protected def outputObjAttr: Attribute
def outputObjAttr: Attribute

override def output: Seq[Attribute] = outputObjAttr :: Nil

override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)

def outputObjectType: DataType = outputObjAttr.dataType
}

/**
@@ -64,7 +66,7 @@ trait ObjectConsumer extends UnaryNode {
// This operator always need all columns of its child, even it doesn't reference to.
override def references: AttributeSet = child.outputSet

def inputObjectType: DataType = child.output.head.dataType
def inputObjAttr: Attribute = child.output.head
}

/**
@@ -167,6 +169,43 @@ case class MapElements(
outputObjAttr: Attribute,
child: LogicalPlan) extends ObjectConsumer with ObjectProducer

object TypedFilter {
def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
}
}

/**
* A relation produced by applying `func` to each element of the `child` and filtering them by the
* resulting boolean value.
*
* This is logically equal to a normal [[Filter]] operator whose condition expression decodes the
* input row to an object and applies the given function to the decoded object. However, we need
* the encapsulation of [[TypedFilter]] to make the concept clearer and to make it easier to write
* optimizer rules.
*/
case class TypedFilter(
func: AnyRef,
deserializer: Expression,
child: LogicalPlan) extends UnaryNode {

override def output: Seq[Attribute] = child.output

def withObjectProducerChild(obj: LogicalPlan): Filter = {
assert(obj.output.length == 1)
Filter(typedCondition(obj.output.head), obj)
}

def typedCondition(input: Expression): Expression = {
val (funcClass, methodName) = func match {
case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
case _ => classOf[Any => Boolean] -> "apply"
}
val funcObj = Literal.create(func, ObjectType(funcClass))
Invoke(funcObj, methodName, BooleanType, input :: Nil)
}
}
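
A brief sketch of the dispatch `typedCondition` performs on the wrapped function: a Java `FilterFunction` is invoked through its `call` method, a Scala closure through `apply`. The helper below is hypothetical and only mirrors how the method name is chosen.

import org.apache.spark.api.java.function.FilterFunction

// Hypothetical helper mirroring the (funcClass, methodName) choice in typedCondition.
def methodNameFor(func: AnyRef): String = func match {
  case _: FilterFunction[_] => "call"  // Java API predicate
  case _ => "apply"                    // Scala closure
}

val javaPredicate = new FilterFunction[String] {
  override def call(value: String): Boolean = value.nonEmpty
}
val scalaPredicate: String => Boolean = _.nonEmpty

assert(methodNameFor(javaPredicate) == "call")
assert(methodNameFor(scalaPredicate) == "apply")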

/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, TypedFilter}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.BooleanType

@@ -33,44 +32,91 @@ class TypedFilterOptimizationSuite extends PlanTest {
val batches =
Batch("EliminateSerialization", FixedPoint(50),
EliminateSerialization) ::
Batch("EmbedSerializerInFilter", FixedPoint(50),
EmbedSerializerInFilter) :: Nil
Batch("CombineTypedFilters", FixedPoint(50),
CombineTypedFilters) :: Nil
}

implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()

test("back to back filter") {
test("filter after serialize with the same object type") {
val input = LocalRelation('_1.int, '_2.int)
val f1 = (i: (Int, Int)) => i._1 > 0
val f2 = (i: (Int, Int)) => i._2 > 0
val f = (i: (Int, Int)) => i._1 > 0

val query = input.filter(f1).filter(f2).analyze
val query = input
.deserialize[(Int, Int)]
.serialize[(Int, Int)]
.filter(f).analyze

val optimized = Optimize.execute(query)

val expected = input.deserialize[(Int, Int)]
.where(callFunction(f1, BooleanType, 'obj))
.select('obj.as("obj"))
.where(callFunction(f2, BooleanType, 'obj))
val expected = input
.deserialize[(Int, Int)]
.where(callFunction(f, BooleanType, 'obj))
.serialize[(Int, Int)].analyze

comparePlans(optimized, expected)
}

// TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules
// for typed filters.
ignore("embed deserializer in typed filter condition if there is only one filter") {
test("filter after serialize with different object types") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: OtherTuple) => i._1 > 0

val query = input
.deserialize[(Int, Int)]
.serialize[(Int, Int)]
.filter(f).analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, query)
}

test("filter before deserialize with the same object type") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: (Int, Int)) => i._1 > 0

val query = input.filter(f).analyze
val query = input
.filter(f)
.deserialize[(Int, Int)]
.serialize[(Int, Int)].analyze

val optimized = Optimize.execute(query)

val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
val condition = callFunction(f, BooleanType, deserializer)
val expected = input.where(condition).select('_1.as("_1"), '_2.as("_2")).analyze
val expected = input
.deserialize[(Int, Int)]
.where(callFunction(f, BooleanType, 'obj))
.serialize[(Int, Int)].analyze

comparePlans(optimized, expected)
}

test("filter before deserialize with different object types") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: OtherTuple) => i._1 > 0

val query = input
.filter(f)
.deserialize[(Int, Int)]
.serialize[(Int, Int)].analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, query)
}

test("back to back filter with the same object type") {
val input = LocalRelation('_1.int, '_2.int)
val f1 = (i: (Int, Int)) => i._1 > 0
val f2 = (i: (Int, Int)) => i._2 > 0

val query = input.filter(f1).filter(f2).analyze
val optimized = Optimize.execute(query)
assert(optimized.collect { case t: TypedFilter => t }.length == 1)
}

test("back to back filter with different object types") {
val input = LocalRelation('_1.int, '_2.int)
val f1 = (i: (Int, Int)) => i._1 > 0
val f2 = (i: OtherTuple) => i._2 > 0

val query = input.filter(f1).filter(f2).analyze
val optimized = Optimize.execute(query)
assert(optimized.collect { case t: TypedFilter => t }.length == 2)
}
}
12 changes: 2 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1997,11 +1997,7 @@ class Dataset[T] private[sql](
*/
@Experimental
def filter(func: T => Boolean): Dataset[T] = {
val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil)
val filter = Filter(condition, logicalPlan)
withTypedPlan(filter)
withTypedPlan(TypedFilter(func, logicalPlan))
}

/**
@@ -2014,11 +2010,7 @@ class Dataset[T] private[sql](
*/
@Experimental
def filter(func: FilterFunction[T]): Dataset[T] = {
val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
val condition = Invoke(function, "call", BooleanType, deserializer :: Nil)
val filter = Filter(condition, logicalPlan)
withTypedPlan(filter)
withTypedPlan(TypedFilter(func, logicalPlan))
}
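
A usage sketch (assuming an existing `Dataset[String]` named `ds`): both overloads now wrap the user function in the same `TypedFilter` logical node; only the shape of the function differs.

import org.apache.spark.api.java.function.FilterFunction

val fromScalaClosure = ds.filter((s: String) => s.startsWith("a"))
val fromJavaFunction = ds.filter(new FilterFunction[String] {
  override def call(s: String): Boolean = s.startsWith("a")
})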

/**
@@ -385,6 +385,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.ProjectExec(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.FilterExec(condition, planLater(child)) :: Nil
case f: logical.TypedFilter =>
execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>