diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 763841efbd9f..034fe56fbe17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -238,6 +238,15 @@ private[parquet] object ParquetFilters {
       case sources.Not(pred) =>
         createFilter(schema, pred).map(FilterApi.not)
 
+      case sources.In(name, values) if canMakeFilterOn(name) =>
+        // IN (v1, ..., vn) is the disjunction col = v1 OR ... OR col = vn, so
+        // reuse the per-type equality builder and OR the results together
+        // (left-associated). reduceLeftOption yields None for an empty IN list
+        // or an unsupported column type instead of throwing on conds(0).
+        makeEq.lift(nameToType(name)).flatMap { eq =>
+          values.map(v => eq(name, v)).reduceLeftOption(FilterApi.or(_, _))
+        }
+
       case _ => None
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 33801954ebd5..410b32d5bde2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import java.nio.charset.StandardCharsets
 
-import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
+import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators}
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -602,6 +602,40 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       }
     }
   }
+
+  test("Convert IN predicate to Parquet filter predicate") {
+    val schema = StructType(Seq(
+      StructField("a", IntegerType, nullable = false)
+    ))
+
+    assertResult(Some(
+      FilterApi.eq(intColumn("a"), 10: Integer))
+    ) {
+      ParquetFilters.createFilter(
+        schema,
+        sources.In("a", Array(10)))
+    }
+
+    assertResult(Some(or(
+      FilterApi.eq(intColumn("a"), 10: Integer),
+      FilterApi.eq(intColumn("a"), 20: Integer)))
+    ) {
+      ParquetFilters.createFilter(
+        schema,
+        sources.In("a", Array(10, 20)))
+    }
+
+    assertResult(Some(or(or(
+      FilterApi.eq(intColumn("a"), 10: Integer),
+      FilterApi.eq(intColumn("a"), 20: Integer)),
+      FilterApi.eq(intColumn("a"), 30: Integer)))
+    ) {
+      ParquetFilters.createFilter(
+        schema,
+        sources.In("a", Array(10, 20, 30)))
+    }
+  }
+
 }
 
 class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {