@@ -21,14 +21,14 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpressi
2121import org .apache .spark .sql .catalyst .planning .ScanOperation
2222import org .apache .spark .sql .catalyst .plans .logical .{Filter , LogicalPlan , Project }
2323import org .apache .spark .sql .catalyst .rules .Rule
24- import org .apache .spark .sql .catalyst . trees . TreeNodeTag
24+ import org .apache .spark .sql .connector . read .{ Scan , V1Scan }
2525import org .apache .spark .sql .execution .datasources .DataSourceStrategy
26+ import org .apache .spark .sql .sources
27+ import org .apache .spark .sql .types .StructType
2628
2729object V2ScanRelationPushDown extends Rule [LogicalPlan ] {
2830 import DataSourceV2Implicits ._
2931
30- val PUSHED_FILTERS_TAG = TreeNodeTag [Seq [org.apache.spark.sql.sources.Filter ]](" pushed_filters" )
31-
3232 override def apply (plan : LogicalPlan ): LogicalPlan = plan transformDown {
3333 case ScanOperation (project, filters, relation : DataSourceV2Relation ) =>
3434 val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options)
@@ -57,8 +57,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
5757 |Output: ${output.mkString(" , " )}
5858 """ .stripMargin)
5959
60- val scanRelation = DataSourceV2ScanRelation (relation.table, scan, output)
61- scanRelation.setTagValue(PUSHED_FILTERS_TAG , pushedFilters)
60+ val wrappedScan = scan match {
61+ case v1 : V1Scan =>
62+ val translated = filters.flatMap(DataSourceStrategy .translateFilter)
63+ V1ScanWrapper (v1, translated, pushedFilters)
64+ case _ => scan
65+ }
66+
67+ val scanRelation = DataSourceV2ScanRelation (relation.table, wrappedScan, output)
6268
6369 val projectionOverSchema = ProjectionOverSchema (output.toStructType)
6470 val projectionFunc = (expr : Expression ) => expr transformDown {
@@ -81,3 +87,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
8187 withProjection
8288 }
8389}
90+
91+ // A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by
92+ // the physical v1 scan node.
93+ case class V1ScanWrapper (
94+ v1Scan : V1Scan ,
95+ translatedFilters : Seq [sources.Filter ],
96+ handledFilters : Seq [sources.Filter ]) extends Scan {
97+ override def readSchema (): StructType = v1Scan.readSchema()
98+ }
0 commit comments