diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
index 65d029dc309b5..8495aebc3eba4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
@@ -17,9 +17,16 @@
 
 package org.apache.spark.sql.connector.read;
 
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.spark.annotation.Experimental;
 import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.sources.Filter;
+import org.apache.spark.sql.internal.connector.PredicateUtils;
+
+import scala.Option;
 
 /**
  * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
@@ -30,7 +37,7 @@
  * @since 3.2.0
  */
 @Experimental
-public interface SupportsRuntimeFiltering extends Scan {
+public interface SupportsRuntimeFiltering extends Scan, SupportsRuntimeV2Filtering {
   /**
    * Returns attributes this scan can be filtered by at runtime.
    * <p>
@@ -57,4 +64,18 @@ public interface SupportsRuntimeFiltering extends Scan {
    * @param filters data source filters used to filter the scan at runtime
    */
   void filter(Filter[] filters);
+
+  default void filter(Predicate[] predicates) {
+    List filterList = new ArrayList();
+
+    for (int i = 0; i < predicates.length; i++) {
+      Option filter = PredicateUtils.toV1(predicates[i]);
+      if (filter.nonEmpty()) {
+        filterList.add((Filter)filter.get());
+      }
+    }
+
+    Filter[] filters = new Filter[filterList.size()];
+    this.filter(filterList.toArray(filters));
+  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java
new file mode 100644
index 0000000000000..7c238bde969b2
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.sql.sources.Filter;
+
+/**
+ * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
+ * filter initially planned {@link InputPartition}s using predicates Spark infers at runtime.
+ * This interface is very similar to {@link SupportsRuntimeFiltering} except it uses
+ * data source V2 {@link Predicate} instead of data source V1 {@link Filter}.
+ * {@link SupportsRuntimeV2Filtering} is preferred over {@link SupportsRuntimeFiltering}
+ * and only one of them should be implemented by the data sources.
+ *
+ * <p>
+ * Note that Spark will push runtime filters only if they are beneficial.
+ *
+ * @since 3.4.0
+ */
+@Experimental
+public interface SupportsRuntimeV2Filtering extends Scan {
+  /**
+   * Returns attributes this scan can be filtered by at runtime.
+   * <p>
+   * Spark will call {@link #filter(Predicate[])} if it can derive a runtime
+   * predicate for any of the filter attributes.
+   */
+  NamedReference[] filterAttributes();
+
+  /**
+   * Filters this scan using runtime predicates.
+   * <p>
+   * The provided expressions must be interpreted as a set of predicates that are ANDed together.
+   * Implementations may use the predicates to prune initially planned {@link InputPartition}s.
+   * <p>
+   * If the scan also implements {@link SupportsReportPartitioning}, it must preserve
+   * the originally reported partitioning during runtime filtering. While applying runtime
+   * predicates, the scan may detect that some {@link InputPartition}s have no matching data. It
+   * can omit such partitions entirely only if it does not report a specific partitioning.
+   * Otherwise, the scan can replace the initially planned {@link InputPartition}s that have no
+   * matching data with empty {@link InputPartition}s but must preserve the overall number of
+   * partitions.
+   * <p>
+   * Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
+   *
+   * @param predicates data source V2 predicates used to filter the scan at runtime
+   */
+  void filter(Predicate[] predicates);
+}
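
To make the new contract concrete, here is a minimal connector-side sketch of SupportsRuntimeV2Filtering. It is not part of this patch; the names KeyedScan, KeyedPartition and keyColumn are invented for illustration. The scan declares its partition key as filterable and drops planned partitions when Spark pushes a runtime IN predicate on that key:

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.types.StructType

// Hypothetical partition type that knows the value of its partition key.
case class KeyedPartition(key: String) extends InputPartition

// Toy scan: it does not report a specific partitioning, so it may drop
// non-matching partitions outright instead of replacing them with empty ones.
class KeyedScan(
    schema: StructType,
    keyColumn: String,
    private var partitions: Seq[KeyedPartition],
    readerFactory: PartitionReaderFactory)
  extends Scan with Batch with SupportsRuntimeV2Filtering {

  override def readSchema(): StructType = schema
  override def toBatch: Batch = this
  override def planInputPartitions(): Array[InputPartition] = partitions.toArray
  override def createReaderFactory(): PartitionReaderFactory = readerFactory

  // Spark only derives runtime predicates for the attributes returned here.
  override def filterAttributes(): Array[NamedReference] = Array(FieldReference(keyColumn))

  // The predicates are ANDed together; this sketch only understands IN(keyColumn, literals...).
  override def filter(predicates: Array[Predicate]): Unit = predicates.foreach {
    case p if p.name() == "IN" && p.children().length > 1 =>
      val keys = p.children().drop(1).collect { case l: LiteralValue[_] => l.value.toString }.toSet
      partitions = partitions.filter(part => keys.contains(part.key))
    case _ => // unsupported predicate: runtime filtering is best-effort, keep all partitions
  }
}
```

Because the legacy SupportsRuntimeFiltering now extends this interface and supplies a default filter(Predicate[]) that converts back to v1 Filters, existing v1 implementations keep working unchanged.
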
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index a8c877a29de8a..d662b83eaf015 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -68,6 +68,7 @@ object Literal {
     case b: Byte => Literal(b, ByteType)
     case s: Short => Literal(s, ShortType)
     case s: String => Literal(UTF8String.fromString(s), StringType)
+    case s: UTF8String => Literal(s, StringType)
     case c: Char => Literal(UTF8String.fromString(c.toString), StringType)
     case ac: Array[Char] => Literal(UTF8String.fromString(String.valueOf(ac)), StringType)
     case b: Boolean => Literal(b, BooleanType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
new file mode 100644
index 0000000000000..ace6b30d4ccec
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.sources.{Filter, In}
+
+private[sql] object PredicateUtils {
+
+  def toV1(predicate: Predicate): Option[Filter] = {
+    predicate.name() match {
+      // TODO: add conversion for other V2 Predicate
+      case "IN" if predicate.children()(0).isInstanceOf[NamedReference] =>
+        val attribute = predicate.children()(0).toString
+        val values = predicate.children().drop(1)
+        if (values.length > 0) {
+          if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None
+          val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType
+          if (!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) {
+            return None
+          }
+          val inValues = values.map(v =>
+            CatalystTypeConverters.convertToScala(v.asInstanceOf[LiteralValue[_]].value, dataType))
+          Some(In(attribute, inValues))
+        } else {
+          Some(In(attribute, Array.empty[Any]))
+        }
+
+      case _ => None
+    }
+  }
+}
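
As a sanity check on the conversion above, here is a test-style sketch (not from the patch) of PredicateUtils.toV1 applied to an IN predicate over a string column. The column name and values are made up, and since PredicateUtils is private[sql], code like this can only live inside Spark's own sql packages. String literals arrive as UTF8String at this layer, which is presumably why literals.scala gains the UTF8String case:

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.internal.connector.PredicateUtils
import org.apache.spark.sql.sources.In
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// IN(store_id, 'CA', 'NY') expressed as a data source V2 predicate.
val inValues = Array("CA", "NY").map(s => LiteralValue(UTF8String.fromString(s), StringType))
val children: Array[Expression] = FieldReference("store_id") +: inValues
val predicate = new Predicate("IN", children)

PredicateUtils.toV1(predicate) match {
  case Some(In(attribute, values)) =>
    // attribute == "store_id"; values are converted back to Scala strings: Array("CA", "NY")
    println(s"v1 filter: In($attribute, [${values.mkString(", ")}])")
  case _ =>
    // Returned when the first child is not a column reference, or the literal types differ.
    println("not convertible to a v1 Filter")
}
```

The default filter(Predicate[]) added to SupportsRuntimeFiltering relies on exactly this conversion, so the runtime IN predicate produced by the planner reaches v1 implementations as the familiar sources.In filter.
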
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index 995c57c062e8a..3255dee0a16b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -268,12 +268,11 @@ class InMemoryTable(
 
   case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics
 
-  case class InMemoryBatchScan(
+  abstract class BatchScanBaseClass(
       var data: Seq[InputPartition],
       readSchema: StructType,
       tableSchema: StructType)
-    extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics
-    with SupportsReportPartitioning {
+    extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning {
 
     override def toBatch: Batch = this
 
@@ -308,6 +307,13 @@ class InMemoryTable(
       val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name))
       new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema)
     }
+  }
+
+  case class InMemoryBatchScan(
+      var _data: Seq[InputPartition],
+      readSchema: StructType,
+      tableSchema: StructType)
+    extends BatchScanBaseClass (_data, readSchema, tableSchema) with SupportsRuntimeFiltering {
 
     override def filterAttributes(): Array[NamedReference] = {
       val scanFields = readSchema.fields.map(_.name).toSet
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
new file mode 100644
index 0000000000000..896c2919c1476
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class InMemoryTableWithV2Filter(
+    name: String,
+    schema: StructType,
+    partitioning: Array[Transform],
+    properties: util.Map[String, String])
+  extends InMemoryTable(name, schema, partitioning, properties) {
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+    new InMemoryV2FilterScanBuilder(schema)
+  }
+
+  class InMemoryV2FilterScanBuilder(tableSchema: StructType)
+    extends InMemoryScanBuilder(tableSchema) {
+    override def build: Scan =
+      InMemoryV2FilterBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
+  }
+
+  case class InMemoryV2FilterBatchScan(
+      var _data: Seq[InputPartition],
+      readSchema: StructType,
+      tableSchema: StructType)
+    extends BatchScanBaseClass (_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {
+
+    override def filterAttributes(): Array[NamedReference] = {
+      val scanFields = readSchema.fields.map(_.name).toSet
+      partitioning.flatMap(_.references)
+        .filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
+    }
+
+    override def filter(filters: Array[Predicate]): Unit = {
+      if (partitioning.length == 1 && partitioning.head.references().length == 1) {
+        val ref = partitioning.head.references().head
+        filters.foreach {
+          case p : Predicate if p.name().equals("IN") =>
+            if (p.children().length > 1) {
+              val filterRef = p.children()(0).asInstanceOf[FieldReference].references.head
+              if (filterRef.toString.equals(ref.toString)) {
+                val matchingKeys =
+                  p.children().drop(1).map(_.asInstanceOf[LiteralValue[_]].value.toString).toSet
+                data = data.filter(partition => {
+                  val key = partition.asInstanceOf[BufferedRows].keyString
+                  matchingKeys.contains(key)
+                })
+              }
+            }
+        }
+      }
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
new file mode 100644
index 0000000000000..08c1f65db290c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+
+class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog {
+  import CatalogV2Implicits._
+
+  override def createTable(
+      ident: Identifier,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): Table = {
+    if (tables.containsKey(ident)) {
+      throw new TableAlreadyExistsException(ident)
+    }
+
+    InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
+
+    val tableName = s"$name.${ident.quoted}"
+    val table = new InMemoryTableWithV2Filter(tableName, schema, partitions, properties)
+    tables.put(ident, table)
+    namespaces.putIfAbsent(ident.namespace.toList, Map())
+    table
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 7f30300a39c17..c9e6dd9630466 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
 import org.apache.spark.sql.execution.streaming.StreamingRelation
@@ -652,25 +652,6 @@ object DataSourceStrategy
     }
   }
 
-  /**
-   * Translates a runtime filter into a data source filter.
-   *
-   * Runtime filters usually contain a subquery that must be evaluated before the translation.
-   * If the underlying subquery hasn't completed yet, this method will throw an exception.
-   */
-  protected[sql] def translateRuntimeFilter(expr: Expression): Option[Filter] = expr match {
-    case in @ InSubqueryExec(e @ PushableColumnAndNestedColumn(name), _, _, _, _, _) =>
-      val values = in.values().getOrElse {
-        throw new IllegalStateException(s"Can't translate $in to source filter, no subquery result")
-      }
-      val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
-      Some(sources.In(name, values.map(toScala)))
-
-    case other =>
-      logWarning(s"Can't translate $other to source filter, unsupported expression")
-      None
-  }
-
   /**
    * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
    * and can be handled by `relation`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index ba969eb6ff1a3..f1c43b8f60c96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -27,8 +27,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.InternalRowSet
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeV2Filtering}
 
 /**
  * Physical plan node for scanning a batch of data from a data source v2.
@@ -56,7 +55,7 @@ case class BatchScanExec(
 
   @transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
     val dataSourceFilters = runtimeFilters.flatMap {
-      case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e)
+      case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
       case _ => None
     }
 
@@ -64,7 +63,7 @@ case class BatchScanExec(
       val originalPartitioning = outputPartitioning
 
       // the cast is safe as runtime filters are only assigned if the scan can be filtered
-      val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering]
+      val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
       filterableScan.filter(dataSourceFilters.toArray)
 
       // call toBatch again to get filtered partitions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 16c6b331d1093..907f3be102c9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
 import org.apache.spark.sql.catalyst.catalog.CatalogUtils
@@ -31,14 +32,14 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, ResolveDefaultColumns, V2ExpressionBuilder}
 import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDelete, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable}
 import org.apache.spark.sql.connector.catalog.index.SupportsIndex
-import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
 import org.apache.spark.sql.connector.read.LocalScan
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
 import org.apache.spark.sql.connector.write.V1Write
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnAndNestedColumn}
 import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
 import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
@@ -498,7 +499,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
   }
 }
 
-private[sql] object DataSourceV2Strategy {
+private[sql] object DataSourceV2Strategy extends Logging {
 
   private def translateLeafNodeFilterV2(predicate: Expression): Option[Predicate] = {
     predicate match {
@@ -582,6 +583,25 @@ private[sql] object DataSourceV2Strategy {
         throw new IllegalStateException("Failed to rebuild Expression for filter: " + predicate))
     }
   }
+
+  /**
+   * Translates a runtime filter into a data source v2 Predicate.
+   *
+   * Runtime filters usually contain a subquery that must be evaluated before the translation.
+   * If the underlying subquery hasn't completed yet, this method will throw an exception.
+   */
+  protected[sql] def translateRuntimeFilterV2(expr: Expression): Option[Predicate] = expr match {
+    case in @ InSubqueryExec(PushableColumnAndNestedColumn(name), _, _, _, _, _) =>
+      val values = in.values().getOrElse {
+        throw new IllegalStateException(s"Can't translate $in to v2 Predicate, no subquery result")
+      }
+      val literals = values.map(LiteralValue(_, in.child.dataType))
+      Some(new Predicate("IN", FieldReference(name) +: literals))
+
+    case other =>
+      logWarning(s"Can't translate $other to source filter, unsupported expression")
+      None
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
index 61a243ddb3368..60ecc4b635e57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
@@ -78,7 +78,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
         } else {
           None
         }
-      case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _, _)) =>
+      case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) =>
        val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r)
        if (resExp.references.subsetOf(AttributeSet(filterAttrs))) {
          Some(r)
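
On the planner side, PartitionPruning now matches scans that implement SupportsRuntimeV2Filtering, and DataSourceV2Strategy.translateRuntimeFilterV2 turns the completed pruning subquery into a v2 IN predicate. A small sketch of the shape it emits (not from the patch; the column name, data type and values are invented):

```scala
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.types.IntegerType

// Stand-in for the values InSubqueryExec.values() returns once the subquery has run.
val subqueryValues: Array[Any] = Array(1, 2)

// Equivalent of: new Predicate("IN", FieldReference(name) +: literals)
val children: Array[Expression] =
  FieldReference("store_id") +: subqueryValues.map(LiteralValue(_, IntegerType))
val runtimePredicate = new Predicate("IN", children)
// BatchScanExec hands this predicate to SupportsRuntimeV2Filtering#filter and then
// calls toBatch again to pick up the pruned partitions.
```
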
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index 366120fb66c1a..fd213d120b6a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.GivenWhenThen
 import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.plans.ExistenceJoin
-import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
+import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive._
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -1805,3 +1805,21 @@ class DynamicPartitionPruningV2SuiteAEOff extends DynamicPartitionPruningV2Suite
 
 class DynamicPartitionPruningV2SuiteAEOn extends DynamicPartitionPruningV2Suite
   with EnableAdaptiveExecutionSuite
+
+abstract class DynamicPartitionPruningV2FilterSuite
+  extends DynamicPartitionPruningDataSourceSuiteBase {
+  override protected def runAnalyzeColumnCommands: Boolean = false
+
+  override protected def initState(): Unit = {
+    spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableWithV2FilterCatalog].getName)
+    spark.conf.set("spark.sql.defaultCatalog", "testcat")
+  }
+}
+
+class DynamicPartitionPruningV2FilterSuiteAEOff
+  extends DynamicPartitionPruningV2FilterSuite
+  with DisableAdaptiveExecutionSuite
+
+class DynamicPartitionPruningV2FilterSuiteAEOn
+  extends DynamicPartitionPruningV2FilterSuite
+  with EnableAdaptiveExecutionSuite
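
Finally, an end-to-end usage sketch mirroring what the new suites configure. This is illustrative only: it assumes a running SparkSession with the catalyst test classes on the classpath, the table names and data are made up, and whether dynamic partition pruning actually kicks in still depends on the usual heuristics (selective build side, beneficial pruning, etc.):

```scala
// Register the test catalog that returns InMemoryTableWithV2Filter tables.
spark.conf.set(
  "spark.sql.catalog.testcat",
  classOf[org.apache.spark.sql.connector.catalog.InMemoryTableWithV2FilterCatalog].getName)
spark.conf.set("spark.sql.defaultCatalog", "testcat")

spark.sql("CREATE TABLE fact (id INT, store_id INT) PARTITIONED BY (store_id)")
spark.sql("CREATE TABLE dim (store_id INT, country STRING)")
// ... INSERT some rows into both tables ...

// A selective filter on the dimension side lets PartitionPruning inject a runtime
// IN predicate into the fact scan, now delivered through SupportsRuntimeV2Filtering.
spark.sql(
  """SELECT f.id
    |FROM fact f JOIN dim d ON f.store_id = d.store_id
    |WHERE d.country = 'DE'
    |""".stripMargin).collect()
```
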