diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
index 65d029dc309b5..8495aebc3eba4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
@@ -17,9 +17,16 @@
 
 package org.apache.spark.sql.connector.read;
 
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.spark.annotation.Experimental;
 import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 import org.apache.spark.sql.sources.Filter;
+import org.apache.spark.sql.internal.connector.PredicateUtils;
+
+import scala.Option;
 
 /**
  * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
@@ -30,7 +37,7 @@
  * @since 3.2.0
  */
 @Experimental
-public interface SupportsRuntimeFiltering extends Scan {
+public interface SupportsRuntimeFiltering extends Scan, SupportsRuntimeV2Filtering {
   /**
    * Returns attributes this scan can be filtered by at runtime.
    *
@@ -57,4 +64,18 @@ public interface SupportsRuntimeFiltering extends Scan {
    * @param filters data source filters used to filter the scan at runtime
    */
   void filter(Filter[] filters);
+
+  /**
+   * Converts the given v2 predicates to v1 filters and calls {@link #filter(Filter[])}.
+   */
+  default void filter(Predicate[] predicates) {
+    List<Filter> filterList = new ArrayList<>();
+    for (Predicate predicate : predicates) {
+      Option<Filter> filter = PredicateUtils.toV1(predicate);
+      if (filter.nonEmpty()) {
+        filterList.add(filter.get());
+      }
+    }
+    this.filter(filterList.toArray(new Filter[0]));
+  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java
new file mode 100644
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+
+/**
+ * A mix-in interface for {@link Scan}. Data sources can implement this interface if they can
+ * filter initially planned {@link InputPartition}s using predicates Spark infers at runtime.
+ * <p>
+ * Note that Spark will push runtime filters only if they are beneficial.
+ *
+ * @since 3.4.0
+ */
+@Experimental
+public interface SupportsRuntimeV2Filtering extends Scan {
+ /**
+ * Returns attributes this scan can be filtered by at runtime.
+ *
+ * Spark will call {@link #filter(Predicate[])} if it can derive a runtime
+ * predicate for any of the filter attributes.
+ */
+ NamedReference[] filterAttributes();
+
+ /**
+ * Filters this scan using runtime predicates.
+ *
+ * The provided expressions must be interpreted as a set of predicates that are ANDed together.
+ * Implementations may use the predicates to prune initially planned {@link InputPartition}s.
+ *
+ * If the scan also implements {@link SupportsReportPartitioning}, it must preserve
+ * the originally reported partitioning during runtime filtering. While applying runtime
+ * predicates, the scan may detect that some {@link InputPartition}s have no matching data. It
+ * can omit such partitions entirely only if it does not report a specific partitioning.
+ * Otherwise, the scan can replace the initially planned {@link InputPartition}s that have no
+ * matching data with empty {@link InputPartition}s but must preserve the overall number of
+ * partitions.
+ *
+ * Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
+ *
+ * @param predicates data source V2 predicates used to filter the scan at runtime
+ */
+ void filter(Predicate[] predicates);
+}
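For orientation, here is a minimal connector-side sketch of the new contract. Everything below (ExampleScan, KeyedPartition, the part_key column) is illustrative and not part of this patch; only the SupportsRuntimeV2Filtering interface above is.

  import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference}
  import org.apache.spark.sql.connector.expressions.filter.Predicate
  import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeV2Filtering}
  import org.apache.spark.sql.types.StructType

  // Hypothetical partition type that carries its partition key value.
  case class KeyedPartition(key: String) extends InputPartition

  class ExampleScan(schema: StructType, var partitions: Seq[InputPartition])
    extends Scan with Batch with SupportsRuntimeV2Filtering {

    override def readSchema(): StructType = schema
    override def toBatch: Batch = this
    override def planInputPartitions(): Array[InputPartition] = partitions.toArray
    override def createReaderFactory(): PartitionReaderFactory =
      throw new UnsupportedOperationException("reader omitted in this sketch")

    // Spark derives runtime predicates only over the attributes reported here.
    override def filterAttributes(): Array[NamedReference] =
      Array[NamedReference](FieldReference("part_key"))

    // Predicates arrive ANDed together; this sketch handles only IN over the partition
    // key and drops initially planned partitions whose key has no matching value.
    override def filter(predicates: Array[Predicate]): Unit = {
      predicates.foreach {
        case p if p.name() == "IN" && p.children().head.isInstanceOf[FieldReference] =>
          val matchingKeys = p.children().drop(1).collect {
            case lit: LiteralValue[_] => lit.value.toString
          }.toSet
          partitions = partitions.collect {
            case kp: KeyedPartition if matchingKeys.contains(kp.key) => kp
          }
        case _ => // ignore predicate shapes this sketch does not understand
      }
    }
  }

Since this sketch does not report a partitioning, it may drop empty partitions outright; a scan that also implements SupportsReportPartitioning would instead replace them with empty partitions, as the javadoc above requires.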
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index a8c877a29de8a..d662b83eaf015 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -68,6 +68,7 @@ object Literal {
case b: Byte => Literal(b, ByteType)
case s: Short => Literal(s, ShortType)
case s: String => Literal(UTF8String.fromString(s), StringType)
+ case s: UTF8String => Literal(s, StringType)
case c: Char => Literal(UTF8String.fromString(c.toString), StringType)
case ac: Array[Char] => Literal(UTF8String.fromString(String.valueOf(ac)), StringType)
case b: Boolean => Literal(b, BooleanType)
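The new UTF8String case matters for the runtime-filter path: translateRuntimeFilterV2 below wraps raw subquery results in LiteralValue, and for string columns those raw values are Spark's internal UTF8String, which Literal(...) previously rejected as an unsupported literal type, presumably the reason for this addition. A tiny spark-shell style illustration (the value is made up):

  import org.apache.spark.sql.catalyst.expressions.Literal
  import org.apache.spark.sql.types.StringType
  import org.apache.spark.unsafe.types.UTF8String

  // Previously only java.lang.String matched; the internal representation fell through
  // to the unsupported-literal error. With this change both produce a string Literal.
  val lit = Literal(UTF8String.fromString("DE"))
  assert(lit.dataType == StringType)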
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
new file mode 100644
index 0000000000000..ace6b30d4ccec
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.sources.{Filter, In}
+
+private[sql] object PredicateUtils {
+
+ def toV1(predicate: Predicate): Option[Filter] = {
+ predicate.name() match {
+ // TODO: add conversion for other V2 Predicate
+ case "IN" if predicate.children()(0).isInstanceOf[NamedReference] =>
+ val attribute = predicate.children()(0).toString
+ val values = predicate.children().drop(1)
+ if (values.length > 0) {
+ if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None
+ val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType
+ if (!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) {
+ return None
+ }
+ val inValues = values.map(v =>
+ CatalystTypeConverters.convertToScala(v.asInstanceOf[LiteralValue[_]].value, dataType))
+ Some(In(attribute, inValues))
+ } else {
+ Some(In(attribute, Array.empty[Any]))
+ }
+
+ case _ => None
+ }
+ }
+}
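A small illustration of what the helper produces (the column name and values are made up; PredicateUtils is private[sql], so this only compiles inside Spark's own sql packages):

  import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, LiteralValue}
  import org.apache.spark.sql.connector.expressions.filter.Predicate
  import org.apache.spark.sql.internal.connector.PredicateUtils
  import org.apache.spark.sql.types.IntegerType

  // IN(store_id, 1, 2, 3): the first child is the attribute, the rest are the values.
  val children: Array[V2Expression] = Array(
    FieldReference("store_id"),
    LiteralValue(1, IntegerType), LiteralValue(2, IntegerType), LiteralValue(3, IntegerType))

  // Expected: Some(In("store_id", Array(1, 2, 3))). Non-literal children or literals with
  // mixed data types make the conversion return None instead.
  val v1Filter = PredicateUtils.toV1(new Predicate("IN", children))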
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index 995c57c062e8a..3255dee0a16b0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -268,12 +268,11 @@ class InMemoryTable(
case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics
- case class InMemoryBatchScan(
+ abstract class BatchScanBaseClass(
var data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
- extends Scan with Batch with SupportsRuntimeFiltering with SupportsReportStatistics
- with SupportsReportPartitioning {
+ extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning {
override def toBatch: Batch = this
@@ -308,6 +307,13 @@ class InMemoryTable(
val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name))
new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema)
}
+ }
+
+ case class InMemoryBatchScan(
+ var _data: Seq[InputPartition],
+ readSchema: StructType,
+ tableSchema: StructType)
+    extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeFiltering {
override def filterAttributes(): Array[NamedReference] = {
val scanFields = readSchema.fields.map(_.name).toSet
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
new file mode 100644
index 0000000000000..896c2919c1476
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class InMemoryTableWithV2Filter(
+ name: String,
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: util.Map[String, String])
+ extends InMemoryTable(name, schema, partitioning, properties) {
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new InMemoryV2FilterScanBuilder(schema)
+ }
+
+ class InMemoryV2FilterScanBuilder(tableSchema: StructType)
+ extends InMemoryScanBuilder(tableSchema) {
+ override def build: Scan =
+ InMemoryV2FilterBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema)
+ }
+
+ case class InMemoryV2FilterBatchScan(
+ var _data: Seq[InputPartition],
+ readSchema: StructType,
+ tableSchema: StructType)
+    extends BatchScanBaseClass(_data, readSchema, tableSchema) with SupportsRuntimeV2Filtering {
+
+ override def filterAttributes(): Array[NamedReference] = {
+ val scanFields = readSchema.fields.map(_.name).toSet
+ partitioning.flatMap(_.references)
+ .filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
+ }
+
+ override def filter(filters: Array[Predicate]): Unit = {
+ if (partitioning.length == 1 && partitioning.head.references().length == 1) {
+ val ref = partitioning.head.references().head
+ filters.foreach {
+          case p: Predicate if p.name().equals("IN") =>
+ if (p.children().length > 1) {
+ val filterRef = p.children()(0).asInstanceOf[FieldReference].references.head
+ if (filterRef.toString.equals(ref.toString)) {
+ val matchingKeys =
+ p.children().drop(1).map(_.asInstanceOf[LiteralValue[_]].value.toString).toSet
+ data = data.filter(partition => {
+ val key = partition.asInstanceOf[BufferedRows].keyString
+ matchingKeys.contains(key)
+ })
+ }
+            }
+          case _ => // skip predicates other than IN over the partition key
+        }
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
new file mode 100644
index 0000000000000..08c1f65db290c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+
+class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog {
+ import CatalogV2Implicits._
+
+ override def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident)
+ }
+
+ InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
+
+ val tableName = s"$name.${ident.quoted}"
+ val table = new InMemoryTableWithV2Filter(tableName, schema, partitions, properties)
+ tables.put(ident, table)
+ namespaces.putIfAbsent(ident.namespace.toList, Map())
+ table
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 7f30300a39c17..c9e6dd9630466 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -44,7 +44,7 @@ import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.streaming.StreamingRelation
@@ -652,25 +652,6 @@ object DataSourceStrategy
}
}
- /**
- * Translates a runtime filter into a data source filter.
- *
- * Runtime filters usually contain a subquery that must be evaluated before the translation.
- * If the underlying subquery hasn't completed yet, this method will throw an exception.
- */
- protected[sql] def translateRuntimeFilter(expr: Expression): Option[Filter] = expr match {
- case in @ InSubqueryExec(e @ PushableColumnAndNestedColumn(name), _, _, _, _, _) =>
- val values = in.values().getOrElse {
- throw new IllegalStateException(s"Can't translate $in to source filter, no subquery result")
- }
- val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
- Some(sources.In(name, values.map(toScala)))
-
- case other =>
- logWarning(s"Can't translate $other to source filter, unsupported expression")
- None
- }
-
/**
* Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
* and can be handled by `relation`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index ba969eb6ff1a3..f1c43b8f60c96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -27,8 +27,7 @@ import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.InternalRowSet
import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeFiltering}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, PartitionReaderFactory, Scan, SupportsRuntimeV2Filtering}
/**
* Physical plan node for scanning a batch of data from a data source v2.
@@ -56,7 +55,7 @@ case class BatchScanExec(
@transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
val dataSourceFilters = runtimeFilters.flatMap {
- case DynamicPruningExpression(e) => DataSourceStrategy.translateRuntimeFilter(e)
+ case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
case _ => None
}
@@ -64,7 +63,7 @@ case class BatchScanExec(
val originalPartitioning = outputPartitioning
// the cast is safe as runtime filters are only assigned if the scan can be filtered
- val filterableScan = scan.asInstanceOf[SupportsRuntimeFiltering]
+ val filterableScan = scan.asInstanceOf[SupportsRuntimeV2Filtering]
filterableScan.filter(dataSourceFilters.toArray)
// call toBatch again to get filtered partitions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 16c6b331d1093..907f3be102c9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.JavaConverters._
import scala.collection.mutable
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
import org.apache.spark.sql.catalyst.catalog.CatalogUtils
@@ -31,14 +32,14 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{toPrettySQL, ResolveDefaultColumns, V2ExpressionBuilder}
import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDelete, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
-import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
import org.apache.spark.sql.connector.read.LocalScan
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.connector.write.V1Write
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnAndNestedColumn}
import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
@@ -498,7 +499,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}
}
-private[sql] object DataSourceV2Strategy {
+private[sql] object DataSourceV2Strategy extends Logging {
private def translateLeafNodeFilterV2(predicate: Expression): Option[Predicate] = {
predicate match {
@@ -582,6 +583,25 @@ private[sql] object DataSourceV2Strategy {
throw new IllegalStateException("Failed to rebuild Expression for filter: " + predicate))
}
}
+
+ /**
+ * Translates a runtime filter into a data source v2 Predicate.
+ *
+ * Runtime filters usually contain a subquery that must be evaluated before the translation.
+ * If the underlying subquery hasn't completed yet, this method will throw an exception.
+ */
+ protected[sql] def translateRuntimeFilterV2(expr: Expression): Option[Predicate] = expr match {
+ case in @ InSubqueryExec(PushableColumnAndNestedColumn(name), _, _, _, _, _) =>
+ val values = in.values().getOrElse {
+ throw new IllegalStateException(s"Can't translate $in to v2 Predicate, no subquery result")
+ }
+ val literals = values.map(LiteralValue(_, in.child.dataType))
+ Some(new Predicate("IN", FieldReference(name) +: literals))
+
+ case other =>
+      logWarning(s"Can't translate $other to v2 Predicate, unsupported expression")
+ None
+ }
}
/**
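To make the shape concrete: once the DPP subquery has finished, an InSubqueryExec over a pushable column becomes an IN predicate whose first child is the column reference and whose remaining children wrap the raw subquery results. A sketch of the equivalent construction (the column name and values are made up; the real path goes through translateRuntimeFilterV2 above):

  import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, LiteralValue}
  import org.apache.spark.sql.connector.expressions.filter.Predicate
  import org.apache.spark.sql.types.IntegerType

  // Hypothetical subquery result: the dimension-side join keys that survived its filter.
  val subqueryValues: Array[Any] = Array(1, 4, 9)

  // Shape of the predicate handed to SupportsRuntimeV2Filtering.filter(...): IN(store_id, 1, 4, 9).
  val children: Array[V2Expression] =
    FieldReference("store_id") +: subqueryValues.map(v => LiteralValue(v, IntegerType))
  val runtimePredicate = new Predicate("IN", children)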
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
index 61a243ddb3368..60ecc4b635e57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.JoinSelectionHelper
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
@@ -78,7 +78,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
} else {
None
}
- case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeFiltering, _, _, _)) =>
+ case (resExp, r @ DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _)) =>
val filterAttrs = V2ExpressionUtils.resolveRefs[Attribute](scan.filterAttributes, r)
if (resExp.references.subsetOf(AttributeSet(filterAttrs))) {
Some(r)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index 366120fb66c1a..fd213d120b6a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.GivenWhenThen
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
-import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog
+import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, InMemoryTableWithV2FilterCatalog}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -1805,3 +1805,21 @@ class DynamicPartitionPruningV2SuiteAEOff extends DynamicPartitionPruningV2Suite
class DynamicPartitionPruningV2SuiteAEOn extends DynamicPartitionPruningV2Suite
with EnableAdaptiveExecutionSuite
+
+abstract class DynamicPartitionPruningV2FilterSuite
+ extends DynamicPartitionPruningDataSourceSuiteBase {
+ override protected def runAnalyzeColumnCommands: Boolean = false
+
+ override protected def initState(): Unit = {
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableWithV2FilterCatalog].getName)
+ spark.conf.set("spark.sql.defaultCatalog", "testcat")
+ }
+}
+
+class DynamicPartitionPruningV2FilterSuiteAEOff
+ extends DynamicPartitionPruningV2FilterSuite
+ with DisableAdaptiveExecutionSuite
+
+class DynamicPartitionPruningV2FilterSuiteAEOn
+ extends DynamicPartitionPruningV2FilterSuite
+ with EnableAdaptiveExecutionSuite
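The new suites rerun the DPP test matrix against InMemoryTableWithV2FilterCatalog, so the runtime filter reaches InMemoryV2FilterBatchScan.filter as a v2 Predicate rather than a v1 Filter. Roughly, queries of this shape exercise the path (table and column names are illustrative, not the suite's actual fixtures):

  // A partitioned DSv2 fact table joined to a selectively filtered dimension table:
  // the filter on `dim` lets Spark derive a runtime IN predicate over fact.store_id.
  spark.sql(
    """
      |SELECT f.date_id, f.units
      |FROM fact f JOIN dim d ON f.store_id = d.store_id
      |WHERE d.country = 'DE'
      |""".stripMargin).collect()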