Commit 8da4314
cloud-fan authored and liancheng committed

[SPARK-16134][SQL] optimizer rules for typed filter
## What changes were proposed in this pull request?

This PR adds 3 optimizer rules for typed filter:

1. Push a typed filter down through `SerializeFromObject` and eliminate the deserialization in the filter condition.
2. Pull a typed filter up through `DeserializeToObject` and eliminate the deserialization in the filter condition.
3. Combine adjacent typed filters and share the deserialized object among all the condition expressions.

This PR also adds a `TypedFilter` logical plan, to separate it from the normal filter, so that the concept is clearer and it's easier to write optimizer rules.

## How was this patch tested?

`TypedFilterOptimizationSuite`

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13846 from cloud-fan/filter.

(cherry picked from commit d063898)
Signed-off-by: Cheng Lian <lian@databricks.com>
1 parent 011befd commit 8da4314
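
To make the three plan shapes concrete, here is a minimal, illustrative sketch (not part of this patch, assuming a local `SparkSession`) of Dataset queries that exercise each rule; `explain(true)` prints the analyzed and optimized plans so the effect on `SerializeFromObject`/`DeserializeToObject` can be inspected:

```scala
import org.apache.spark.sql.SparkSession

object TypedFilterShapesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("typed-filter-shapes").getOrCreate()
    import spark.implicits._

    val ds = Seq((1, 2), (-3, 4), (5, -6)).toDS()

    // Rule 1: a typed filter right after a map over the same object type can be pushed
    // below the SerializeFromObject that map introduces.
    val pushDown = ds.map(t => (t._1 + 1, t._2)).filter(t => t._1 > 0)

    // Rule 2: a typed filter right before a map over the same object type can be pulled
    // above the DeserializeToObject that map introduces.
    val pullUp = ds.filter(t => t._1 > 0).map(t => t._1 + t._2)

    // Rule 3: two adjacent typed filters over the same object type are merged into one.
    val combined = ds.filter(t => t._1 > 0).filter(t => t._2 > 0)

    pushDown.explain(true)
    pullUp.explain(true)
    combined.explain(true)

    spark.stop()
  }
}
```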

File tree

8 files changed (+162, -91 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 1 addition & 5 deletions
@@ -293,11 +293,7 @@ package object dsl {
 
     def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
 
-    def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
-      val deserialized = logicalPlan.deserialize[T]
-      val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
-      Filter(condition, deserialized).serialize[T]
-    }
+    def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan)
 
     def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
 

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
     var maxOrdinal = -1
     result foreach {
       case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
+      case _ =>
     }
     if (maxOrdinal > children.length) {
       return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 46 additions & 52 deletions
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
 import scala.collection.immutable.HashSet
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
@@ -109,8 +110,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
     Batch("Typed Filter Optimization", fixedPoint,
-      EmbedSerializerInFilter,
-      RemoveAliasOnlyProject) ::
+      CombineTypedFilters) ::
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation) ::
     Batch("OptimizeCodegen", Once,
@@ -205,15 +205,33 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
 object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
-        if d.outputObjectType == s.inputObjectType =>
+        if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
       // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
       // We will remove it later in RemoveAliasOnlyProject rule.
-      val objAttr =
-        Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
+      val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
       Project(objAttr :: Nil, s.child)
+
     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
-        if a.deserializer.dataType == s.inputObjectType =>
+        if a.deserializer.dataType == s.inputObjAttr.dataType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
+
+    // If there is a `SerializeFromObject` under typed filter and its input object type is same with
+    // the typed filter's deserializer, we can convert typed filter to normal filter without
+    // deserialization in condition, and push it down through `SerializeFromObject`.
+    // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
+    // but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
+    case f @ TypedFilter(_, _, s: SerializeFromObject)
+        if f.deserializer.dataType == s.inputObjAttr.dataType =>
+      s.copy(child = f.withObjectProducerChild(s.child))
+
+    // If there is a `DeserializeToObject` upon typed filter and its output object type is same with
+    // the typed filter's deserializer, we can convert typed filter to normal filter without
+    // deserialization in condition, and pull it up through `DeserializeToObject`.
+    // e.g. `ds.filter(...).map(...)` can be optimized by this rule to save extra deserialization,
+    // but `ds.filter(...).as[AnotherType].map(...)` can not be optimized.
+    case d @ DeserializeToObject(_, _, f: TypedFilter)
+        if d.outputObjAttr.dataType == f.deserializer.dataType =>
+      f.withObjectProducerChild(d.copy(child = f.child))
   }
 }
 
@@ -1606,54 +1624,30 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
 }
 
 /**
- * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
- * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
- * the deserializer in filter condition to save the extra serialization at last.
+ * Combines two adjacent [[TypedFilter]]s, which operate on same type object in condition, into one,
+ * merging the filter functions into one conjunctive function.
  */
-object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+object CombineTypedFilters extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject))
-        // SPARK-15632: Conceptually, filter operator should never introduce schema change. This
-        // optimization rule also relies on this assumption. However, Dataset typed filter operator
-        // does introduce schema changes in some cases. Thus, we only enable this optimization when
-        //
-        // 1. either input and output schemata are exactly the same, or
-        // 2. both input and output schemata are single-field schema and share the same type.
-        //
-        // The 2nd case is included because encoders for primitive types always have only a single
-        // field with hard-coded field name "value".
-        // TODO Cleans this up after fixing SPARK-15632.
-        if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) =>
-
-      val numObjects = condition.collect {
-        case a: Attribute if a == d.output.head => a
-      }.length
-
-      if (numObjects > 1) {
-        // If the filter condition references the object more than one times, we should not embed
-        // deserializer in it as the deserialization will happen many times and slow down the
-        // execution.
-        // TODO: we can still embed it if we can make sure subexpression elimination works here.
-        s
-      } else {
-        val newCondition = condition transform {
-          case a: Attribute if a == d.output.head => d.deserializer
-        }
-        val filter = Filter(newCondition, d.child)
-
-        // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`.
-        // We will remove it later in RemoveAliasOnlyProject rule.
-        val objAttrs = filter.output.zip(s.output).map { case (fout, sout) =>
-          Alias(fout, fout.name)(exprId = sout.exprId)
-        }
-        Project(objAttrs, filter)
-      }
-  }
-
-  def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = {
-    (lhs, rhs) match {
-      case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
-      case _ => false
+    case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+        if t1.deserializer.dataType == t2.deserializer.dataType =>
+      TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+  }
+
+  private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
+    (func1, func2) match {
+      case (f1: FilterFunction[_], f2: FilterFunction[_]) =>
+        input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
+          f2.asInstanceOf[FilterFunction[Any]].call(input)
+      case (f1: FilterFunction[_], f2) =>
+        input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
+          f2.asInstanceOf[Any => Boolean](input)
+      case (f1, f2: FilterFunction[_]) =>
+        input => f1.asInstanceOf[Any => Boolean].apply(input) &&
+          f2.asInstanceOf[FilterFunction[Any]].call(input)
+      case (f1, f2) =>
+        input => f1.asInstanceOf[Any => Boolean].apply(input) &&
+          f2.asInstanceOf[Any => Boolean].apply(input)
     }
   }
 }
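
As an aside, the merged predicate produced by `combineFilterFunction` is just a conjunction of the two user functions evaluated on the same, already deserialized object. Below is a simplified, hypothetical sketch of that behavior for the Scala-function case only (the Java `FilterFunction` combinations handled above are ignored); `combine` is an illustrative helper, not part of the patch:

```scala
object CombinePredicatesSketch {
  // Hypothetical helper mirroring the conjunctive merge done by CombineTypedFilters.
  def combine[T](first: T => Boolean, second: T => Boolean): T => Boolean =
    input => first(input) && second(input)

  def main(args: Array[String]): Unit = {
    val f1 = (i: (Int, Int)) => i._1 > 0
    val f2 = (i: (Int, Int)) => i._2 > 0
    val merged = combine(f1, f2)

    println(merged((1, 2)))   // true: both predicates hold
    println(merged((1, -2)))  // false: the second predicate fails
  }
}
```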

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 43 additions & 4 deletions
@@ -17,11 +17,15 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.types._
 
 object CatalystSerde {
@@ -45,13 +49,11 @@ object CatalystSerde {
  */
 trait ObjectProducer extends LogicalPlan {
   // The attribute that reference to the single object field this operator outputs.
-  protected def outputObjAttr: Attribute
+  def outputObjAttr: Attribute
 
   override def output: Seq[Attribute] = outputObjAttr :: Nil
 
   override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
-
-  def outputObjectType: DataType = outputObjAttr.dataType
 }
 
 /**
@@ -64,7 +66,7 @@ trait ObjectConsumer extends UnaryNode {
   // This operator always need all columns of its child, even it doesn't reference to.
   override def references: AttributeSet = child.outputSet
 
-  def inputObjectType: DataType = child.output.head.dataType
+  def inputObjAttr: Attribute = child.output.head
 }
 
 /**
@@ -167,6 +169,43 @@ case class MapElements(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 
+object TypedFilter {
+  def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
+    TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each element of the `child` and filter them by the
+ * resulting boolean value.
+ *
+ * This is logically equal to a normal [[Filter]] operator whose condition expression is decoding
+ * the input row to object and apply the given function with decoded object. However we need the
+ * encapsulation of [[TypedFilter]] to make the concept more clear and make it easier to write
+ * optimizer rules.
+ */
+case class TypedFilter(
+    func: AnyRef,
+    deserializer: Expression,
+    child: LogicalPlan) extends UnaryNode {
+
+  override def output: Seq[Attribute] = child.output
+
+  def withObjectProducerChild(obj: LogicalPlan): Filter = {
+    assert(obj.output.length == 1)
+    Filter(typedCondition(obj.output.head), obj)
+  }
+
+  def typedCondition(input: Expression): Expression = {
+    val (funcClass, methodName) = func match {
+      case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
+      case _ => classOf[Any => Boolean] -> "apply"
+    }
+    val funcObj = Literal.create(func, ObjectType(funcClass))
+    Invoke(funcObj, methodName, BooleanType, input :: Nil)
+  }
+}
+
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumns {
   def apply[T : Encoder, U : Encoder](
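
A rough sketch (illustrative only, assuming `spark-catalyst` on the classpath and the test-style catalyst DSL imports) of how the new node can be constructed and then rewritten into a plain `Filter` once an object-producing child is available:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, TypedFilter}

object TypedFilterNodeSketch {
  // Encoder for the tuple type used by the predicate below.
  implicit val tupleEncoder: ExpressionEncoder[(Int, Int)] = ExpressionEncoder[(Int, Int)]()

  def main(args: Array[String]): Unit = {
    val input = LocalRelation('_1.int, '_2.int)
    val f = (i: (Int, Int)) => i._1 > 0

    // The companion apply wires in UnresolvedDeserializer(encoderFor[T].deserializer).
    val typed = TypedFilter[(Int, Int)](f, input)
    println(typed)

    // Given an object-producing child (here the deserialized input), the node becomes a
    // plain Filter whose condition invokes `f` on the single object attribute.
    val asFilter = typed.withObjectProducerChild(input.deserialize[(Int, Int)])
    println(asFilter)
  }
}
```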

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala

Lines changed: 66 additions & 20 deletions
@@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, TypedFilter}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types.BooleanType
 
@@ -33,44 +32,91 @@ class TypedFilterOptimizationSuite extends PlanTest {
     val batches =
       Batch("EliminateSerialization", FixedPoint(50),
         EliminateSerialization) ::
-      Batch("EmbedSerializerInFilter", FixedPoint(50),
-        EmbedSerializerInFilter) :: Nil
+      Batch("CombineTypedFilters", FixedPoint(50),
+        CombineTypedFilters) :: Nil
   }
 
   implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
 
-  test("back to back filter") {
+  test("filter after serialize with the same object type") {
     val input = LocalRelation('_1.int, '_2.int)
-    val f1 = (i: (Int, Int)) => i._1 > 0
-    val f2 = (i: (Int, Int)) => i._2 > 0
+    val f = (i: (Int, Int)) => i._1 > 0
 
-    val query = input.filter(f1).filter(f2).analyze
+    val query = input
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)]
+      .filter(f).analyze
 
     val optimized = Optimize.execute(query)
 
-    val expected = input.deserialize[(Int, Int)]
-      .where(callFunction(f1, BooleanType, 'obj))
-      .select('obj.as("obj"))
-      .where(callFunction(f2, BooleanType, 'obj))
+    val expected = input
+      .deserialize[(Int, Int)]
+      .where(callFunction(f, BooleanType, 'obj))
       .serialize[(Int, Int)].analyze
 
     comparePlans(optimized, expected)
   }
 
-  // TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules
-  // for typed filters.
-  ignore("embed deserializer in typed filter condition if there is only one filter") {
+  test("filter after serialize with different object types") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f = (i: OtherTuple) => i._1 > 0
+
+    val query = input
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)]
+      .filter(f).analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
+
+  test("filter before deserialize with the same object type") {
     val input = LocalRelation('_1.int, '_2.int)
     val f = (i: (Int, Int)) => i._1 > 0
 
-    val query = input.filter(f).analyze
+    val query = input
+      .filter(f)
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)].analyze
 
     val optimized = Optimize.execute(query)
 
-    val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
-    val condition = callFunction(f, BooleanType, deserializer)
-    val expected = input.where(condition).select('_1.as("_1"), '_2.as("_2")).analyze
+    val expected = input
+      .deserialize[(Int, Int)]
+      .where(callFunction(f, BooleanType, 'obj))
+      .serialize[(Int, Int)].analyze
 
     comparePlans(optimized, expected)
   }
+
+  test("filter before deserialize with different object types") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f = (i: OtherTuple) => i._1 > 0
+
+    val query = input
+      .filter(f)
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)].analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
+
+  test("back to back filter with the same object type") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f1 = (i: (Int, Int)) => i._1 > 0
+    val f2 = (i: (Int, Int)) => i._2 > 0
+
+    val query = input.filter(f1).filter(f2).analyze
+    val optimized = Optimize.execute(query)
+    assert(optimized.collect { case t: TypedFilter => t }.length == 1)
+  }
+
+  test("back to back filter with different object types") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f1 = (i: (Int, Int)) => i._1 > 0
+    val f2 = (i: OtherTuple) => i._2 > 0
+
+    val query = input.filter(f1).filter(f2).analyze
+    val optimized = Optimize.execute(query)
+    assert(optimized.collect { case t: TypedFilter => t }.length == 2)
+  }
 }
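
The same behavior can also be observed end to end from the Dataset API. A small sketch (assuming a local `SparkSession`; the object name is made up) that counts the `TypedFilter` nodes left in the optimized plan after two adjacent typed filters:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.TypedFilter

object CombineTypedFiltersEndToEnd {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("combine-typed-filters").getOrCreate()
    import spark.implicits._

    val ds = Seq((1, 2), (3, -4), (-5, 6)).toDS()
      .filter(t => t._1 > 0)
      .filter(t => t._2 > 0)

    // If CombineTypedFilters fired, the two adjacent filters collapse into a single node.
    val typedFilters = ds.queryExecution.optimizedPlan.collect { case t: TypedFilter => t }
    println(s"TypedFilter nodes in optimized plan: ${typedFilters.length}")
    println(ds.collect().mkString(", "))

    spark.stop()
  }
}
```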

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 2 additions & 10 deletions
@@ -1964,11 +1964,7 @@ class Dataset[T] private[sql](
    */
   @Experimental
   def filter(func: T => Boolean): Dataset[T] = {
-    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
-    val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
-    val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil)
-    val filter = Filter(condition, logicalPlan)
-    withTypedPlan(filter)
+    withTypedPlan(TypedFilter(func, logicalPlan))
   }
 
   /**
@@ -1981,11 +1977,7 @@ class Dataset[T] private[sql](
    */
   @Experimental
   def filter(func: FilterFunction[T]): Dataset[T] = {
-    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
-    val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
-    val condition = Invoke(function, "call", BooleanType, deserializer :: Nil)
-    val filter = Filter(condition, logicalPlan)
-    withTypedPlan(filter)
+    withTypedPlan(TypedFilter(func, logicalPlan))
   }
 
   /**
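
For reference, both overloads are used exactly as before; a hedged usage sketch (assuming a local `SparkSession`) showing the Scala-function and Java `FilterFunction` variants that now both build a `TypedFilter` logical node:

```scala
import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.SparkSession

object DatasetFilterOverloadsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("dataset-filter").getOrCreate()
    import spark.implicits._

    val ds = spark.range(10).as[Long]

    // Scala-function overload: wrapped in TypedFilter(func, logicalPlan).
    val evens = ds.filter((x: Long) => x % 2 == 0)

    // Java FilterFunction overload: also wrapped in TypedFilter(func, logicalPlan).
    val bigOnes = ds.filter(new FilterFunction[Long] {
      override def call(value: Long): Boolean = value > 3
    })

    println(evens.collect().mkString(", "))
    println(bigOnes.collect().mkString(", "))

    spark.stop()
  }
}
```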

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 0 deletions
@@ -385,6 +385,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.ProjectExec(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
         execution.FilterExec(condition, planLater(child)) :: Nil
+      case f: logical.TypedFilter =>
+        execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
       case e @ logical.Expand(_, _, child) =>
         execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
       case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
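
With this strategy case in place, a typed filter should plan to a `FilterExec` whose condition invokes the predicate on the deserialized object. A quick way to see this, sketched here under the assumption of a local `SparkSession`:

```scala
import org.apache.spark.sql.SparkSession

object TypedFilterPhysicalPlanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("typed-filter-exec").getOrCreate()
    import spark.implicits._

    val ds = spark.range(10).as[Long].filter((x: Long) => x % 3 == 0)

    // The executed plan should contain a FilterExec produced by the new TypedFilter case.
    println(ds.queryExecution.executedPlan)

    spark.stop()
  }
}
```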
