Commit ac06b0c

[SPARK-25306][SQL] Use cache to speed up createFilter
1 parent 6ad8d4c commit ac06b0c

4 files changed: +119 additions, -25 deletions

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala

Lines changed: 49 additions & 10 deletions

@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql.execution.datasources.orc
 
+import java.util.concurrent.TimeUnit
+
+import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument, SearchArgumentFactory}
 import org.apache.orc.storage.ql.io.sarg.SearchArgument.Builder
 import org.apache.orc.storage.serde2.io.HiveDecimalWritable
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._
 
@@ -54,7 +58,37 @@ import org.apache.spark.sql.types._
  * builder methods mentioned above can only be found in test code, where all tested filters are
  * known to be convertible.
  */
-private[orc] object OrcFilters {
+private[sql] object OrcFilters {
+
+  case class FilterWithTypeMap(filter: Filter, typeMap: Map[String, DataType])
+
+  private val defaultCacheExpireTimeout = TimeUnit.SECONDS.toSeconds(20)
+
+  lazy val cacheExpireTimeout: Long =
+    Option(SparkEnv.get).map(_.conf.getTimeAsSeconds(
+      "spark.sql.orc.cache.sarg.timeout",
+      s"${defaultCacheExpireTimeout}s")).getOrElse(defaultCacheExpireTimeout)
+
+  private lazy val searchArgumentCache = CacheBuilder.newBuilder()
+    .expireAfterAccess(cacheExpireTimeout, TimeUnit.SECONDS)
+    .build(
+      new CacheLoader[FilterWithTypeMap, Option[Builder]]() {
+        override def load(typeMapAndFilter: FilterWithTypeMap): Option[Builder] = {
+          buildSearchArgument(
+            typeMapAndFilter.typeMap, typeMapAndFilter.filter, SearchArgumentFactory.newBuilder())
+        }
+      })
+
+  private def getOrBuildSearchArgumentWithNewBuilder(
+      dataTypeMap: Map[String, DataType],
+      expression: Filter): Option[Builder] = {
+    // When `spark.sql.orc.cache.sarg.timeout` is 0, cache is disabled.
+    if (cacheExpireTimeout > 0) {
+      searchArgumentCache.get(FilterWithTypeMap(expression, dataTypeMap))
+    } else {
+      buildSearchArgument(dataTypeMap, expression, SearchArgumentFactory.newBuilder())
+    }
+  }
 
   /**
    * Create ORC filter as a SearchArgument instance.
@@ -66,12 +100,19 @@ private[orc] object OrcFilters {
     // collect all convertible ones to build the final `SearchArgument`.
     val convertibleFilters = for {
       filter <- filters
-      _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder())
+      _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, filter)
     } yield filter
 
     for {
       // Combines all convertible filters using `And` to produce a single conjunction
-      conjunction <- convertibleFilters.reduceOption(org.apache.spark.sql.sources.And)
+      conjunction <- convertibleFilters.reduceOption { (x, y) =>
+        val newFilter = org.apache.spark.sql.sources.And(x, y)
+        if (cacheExpireTimeout > 0) {
+          // Build in a bottom-up manner
+          getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, newFilter)
+        }
+        newFilter
+      }
       // Then tries to build a single ORC `SearchArgument` for the conjunction predicate
       builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder())
     } yield builder.build()
@@ -127,8 +168,6 @@ private[orc] object OrcFilters {
       dataTypeMap: Map[String, DataType],
       expression: Filter,
       builder: Builder): Option[Builder] = {
-    def newBuilder = SearchArgumentFactory.newBuilder()
-
     def getType(attribute: String): PredicateLeaf.Type =
       getPredicateLeafType(dataTypeMap(attribute))
 
@@ -144,23 +183,23 @@ private[orc] object OrcFilters {
         // Pushing one side of AND down is only safe to do at the top level.
         // You can see ParquetRelation's initializeLocalJobFunc method as an example.
         for {
-          _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
-          _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, left)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, right)
          lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd())
          rhs <- buildSearchArgument(dataTypeMap, right, lhs)
        } yield rhs.end()
 
       case Or(left, right) =>
        for {
-          _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
-          _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, left)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, right)
          lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr())
          rhs <- buildSearchArgument(dataTypeMap, right, lhs)
        } yield rhs.end()
 
       case Not(child) =>
        for {
-          _ <- buildSearchArgument(dataTypeMap, child, newBuilder)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, child)
          negate <- buildSearchArgument(dataTypeMap, child, builder.startNot())
        } yield negate.end()
 

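The change above memoizes filter-to-SearchArgument conversion with a Guava LoadingCache keyed by the filter together with its column-type map. The standalone Scala sketch below illustrates that caching pattern in isolation; SargCacheSketch, FilterKey and expensiveConvert are made-up placeholders for FilterWithTypeMap, buildSearchArgument and Option[Builder], not Spark APIs.

import java.util.concurrent.TimeUnit

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

object SargCacheSketch {
  // Placeholder key type standing in for FilterWithTypeMap above.
  case class FilterKey(filter: String, typeMap: Map[String, String])

  // Placeholder for the expensive buildSearchArgument(...) call.
  def expensiveConvert(key: FilterKey): String = s"sarg(${key.filter})"

  // Entries expire a fixed time after their last access, mirroring the 20-second default above.
  val cache: LoadingCache[FilterKey, String] = CacheBuilder.newBuilder()
    .expireAfterAccess(20, TimeUnit.SECONDS)
    .build(new CacheLoader[FilterKey, String]() {
      override def load(key: FilterKey): String = expensiveConvert(key)
    })

  def main(args: Array[String]): Unit = {
    val key = FilterKey("a < 1", Map("a" -> "int"))
    // The first get runs the loader; later gets for an equal key return the same cached instance.
    assert(cache.get(key) eq cache.get(key))
  }
}

Because equal (filter, type map) pairs hash to the same cache entry, repeated trial conversions of the same subtree are paid for only once per expiry window.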
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala

Lines changed: 13 additions & 3 deletions

@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.datasources.orc
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
-import scala.collection.JavaConverters._
-
 import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.scalatest.concurrent.TimeLimits
+import org.scalatest.time.SpanSugar._
+import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.{Column, DataFrame}
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -39,7 +40,7 @@ import org.apache.spark.sql.types._
  * - OrcFilterSuite uses 'org.apache.orc.storage.ql.io.sarg' package.
  * - HiveOrcFilterSuite uses 'org.apache.hadoop.hive.ql.io.sarg' package.
  */
-class OrcFilterSuite extends OrcTest with SharedSQLContext {
+class OrcFilterSuite extends OrcTest with SharedSQLContext with TimeLimits {
 
   private def checkFilterPredicate(
       df: DataFrame,
@@ -383,4 +384,13 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext {
       )).get.toString
     }
   }
+
+  test("createFilter should not hang") {
+    import org.apache.spark.sql.sources._
+    val schema = new StructType(Array(StructField("a", IntegerType, nullable = true)))
+    val filters = (1 to 500).map(LessThan("a", _)).toArray[Filter]
+    failAfter(2 seconds) {
+      OrcFilters.createFilter(schema, filters)
+    }
+  }
 }

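The new test feeds 500 LessThan filters into createFilter and requires it to return within two seconds, which it did not do reliably before the cache was introduced. The toy sketch below (plain Scala, not Spark code; Pred, Leaf, Conj and convert are invented stand-ins) illustrates the "build in a bottom-up manner" idea from the patched reduceOption: each intermediate conjunction node is converted once and memoized, so converting the final left-deep conjunction never re-walks already-converted children.

import scala.collection.mutable

object BottomUpMemoSketch {
  // Made-up predicate tree standing in for Spark's Filter hierarchy.
  sealed trait Pred
  case class Leaf(name: String, value: Int) extends Pred
  case class Conj(left: Pred, right: Pred) extends Pred

  // Memo table playing the role of searchArgumentCache.
  private val memo = mutable.HashMap.empty[Pred, Option[String]]

  // Convert a predicate, reusing any previously converted subtree.
  def convert(p: Pred): Option[String] = memo.get(p) match {
    case Some(cached) => cached
    case None =>
      val result = p match {
        case Leaf(n, v) => Some(s"$n < $v")
        case Conj(l, r) => for { a <- convert(l); b <- convert(r) } yield s"($a AND $b)"
      }
      memo.put(p, result)
      result
  }

  def main(args: Array[String]): Unit = {
    val leaves: Seq[Pred] = (1 to 500).map(Leaf("a", _))
    // Fold the leaves into one left-deep conjunction, converting (and caching)
    // every intermediate node on the way up, as the patched reduceOption does.
    val conjunction = leaves.reduce[Pred] { (x, y) =>
      val node = Conj(x, y)
      convert(node)
      node
    }
    // Converting the full conjunction is now a cache hit rather than a fresh tree walk.
    assert(memo.contains(conjunction))
  }
}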
sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala

Lines changed: 44 additions & 9 deletions

@@ -17,9 +17,13 @@
 
 package org.apache.spark.sql.hive.orc
 
+import java.util.concurrent.TimeUnit
+
+import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory}
 import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -55,19 +59,52 @@ import org.apache.spark.sql.types._
  * known to be convertible.
  */
 private[orc] object OrcFilters extends Logging {
+  case class FilterWithTypeMap(filter: Filter, typeMap: Map[String, DataType])
+
+  private lazy val cacheExpireTimeout =
+    org.apache.spark.sql.execution.datasources.orc.OrcFilters.cacheExpireTimeout
+
+  private lazy val searchArgumentCache = CacheBuilder.newBuilder()
+    .expireAfterAccess(cacheExpireTimeout, TimeUnit.SECONDS)
+    .build(
+      new CacheLoader[FilterWithTypeMap, Option[Builder]]() {
+        override def load(typeMapAndFilter: FilterWithTypeMap): Option[Builder] = {
+          buildSearchArgument(
+            typeMapAndFilter.typeMap, typeMapAndFilter.filter, SearchArgumentFactory.newBuilder())
+        }
+      })
+
+  private def getOrBuildSearchArgumentWithNewBuilder(
+      dataTypeMap: Map[String, DataType],
+      expression: Filter): Option[Builder] = {
+    // When `spark.sql.orc.cache.sarg.timeout` is 0, cache is disabled.
+    if (cacheExpireTimeout > 0) {
+      searchArgumentCache.get(FilterWithTypeMap(expression, dataTypeMap))
+    } else {
+      buildSearchArgument(dataTypeMap, expression, SearchArgumentFactory.newBuilder())
+    }
+  }
+
   def createFilter(schema: StructType, filters: Array[Filter]): Option[SearchArgument] = {
     val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap
 
     // First, tries to convert each filter individually to see whether it's convertible, and then
     // collect all convertible ones to build the final `SearchArgument`.
     val convertibleFilters = for {
       filter <- filters
-      _ <- buildSearchArgument(dataTypeMap, filter, SearchArgumentFactory.newBuilder())
+      _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, filter)
     } yield filter
 
     for {
       // Combines all convertible filters using `And` to produce a single conjunction
-      conjunction <- convertibleFilters.reduceOption(And)
+      conjunction <- convertibleFilters.reduceOption { (x, y) =>
+        val newFilter = org.apache.spark.sql.sources.And(x, y)
+        if (cacheExpireTimeout > 0) {
+          // Build in a bottom-up manner
+          getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, newFilter)
+        }
+        newFilter
+      }
       // Then tries to build a single ORC `SearchArgument` for the conjunction predicate
       builder <- buildSearchArgument(dataTypeMap, conjunction, SearchArgumentFactory.newBuilder())
     } yield builder.build()
@@ -77,8 +114,6 @@ private[orc] object OrcFilters extends Logging {
       dataTypeMap: Map[String, DataType],
      expression: Filter,
      builder: Builder): Option[Builder] = {
-    def newBuilder = SearchArgumentFactory.newBuilder()
-
     def isSearchableType(dataType: DataType): Boolean = dataType match {
       // Only the values in the Spark types below can be recognized by
       // the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method.
@@ -98,23 +133,23 @@
         // Pushing one side of AND down is only safe to do at the top level.
         // You can see ParquetRelation's initializeLocalJobFunc method as an example.
         for {
-          _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
-          _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, left)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, right)
          lhs <- buildSearchArgument(dataTypeMap, left, builder.startAnd())
          rhs <- buildSearchArgument(dataTypeMap, right, lhs)
        } yield rhs.end()
 
       case Or(left, right) =>
        for {
-          _ <- buildSearchArgument(dataTypeMap, left, newBuilder)
-          _ <- buildSearchArgument(dataTypeMap, right, newBuilder)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, left)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, right)
          lhs <- buildSearchArgument(dataTypeMap, left, builder.startOr())
          rhs <- buildSearchArgument(dataTypeMap, right, lhs)
        } yield rhs.end()
 
       case Not(child) =>
        for {
-          _ <- buildSearchArgument(dataTypeMap, child, newBuilder)
+          _ <- getOrBuildSearchArgumentWithNewBuilder(dataTypeMap, child)
          negate <- buildSearchArgument(dataTypeMap, child, builder.startNot())
        } yield negate.end()
 

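Both copies of OrcFilters read their cache timeout from the spark.sql.orc.cache.sarg.timeout setting (the Hive variant reuses the value resolved by the core object above). Below is a minimal sketch, not part of the patch, of how that value could be resolved from an application's SparkConf; the key name, the zero-disables-cache behaviour, and the 20-second fallback all come from the diff, while SargTimeoutConfSketch itself is illustrative only.

import org.apache.spark.SparkConf

object SargTimeoutConfSketch {
  def main(args: Array[String]): Unit = {
    // Setting the timeout to 0 makes getOrBuildSearchArgumentWithNewBuilder bypass the cache.
    val conf = new SparkConf()
      .set("spark.sql.orc.cache.sarg.timeout", "0s")

    // Same kind of resolution the patch performs via SparkEnv.get.conf, with a 20-second fallback.
    val timeoutSeconds = conf.getTimeAsSeconds("spark.sql.orc.cache.sarg.timeout", "20s")
    assert(timeoutSeconds == 0L)
  }
}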
sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala

Lines changed: 13 additions & 3 deletions

@@ -20,9 +20,10 @@ package org.apache.spark.sql.hive.orc
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 
-import scala.collection.JavaConverters._
-
 import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.scalatest.concurrent.TimeLimits
+import org.scalatest.time.SpanSugar._
+import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.{Column, DataFrame}
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -36,7 +37,7 @@ import org.apache.spark.sql.types._
 /**
  * A test suite that tests Hive ORC filter API based filter pushdown optimization.
  */
-class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton {
+class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton with TimeLimits {
 
   override val orcImp: String = "hive"
 
@@ -384,4 +385,13 @@ class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton {
       )).get.toString
     }
   }
+
+  test("createFilter should not hang") {
+    import org.apache.spark.sql.sources._
+    val schema = new StructType(Array(StructField("a", IntegerType, nullable = true)))
+    val filters = (1 to 500).map(LessThan("a", _)).toArray[Filter]
+    failAfter(2 seconds) {
+      OrcFilters.createFilter(schema, filters)
+    }
+  }
 }
