From 9d2923a6bc86ca7975ae383e16922510c42134ba Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 16 Oct 2019 23:49:10 +0800 Subject: [PATCH 1/9] add v1 read fallback API in DS v2 --- .../sql/connector/InMemoryTableCatalog.scala | 4 +- .../spark/sql/connector/read/V1Scan.scala | 43 ++++ .../sql/execution/DataSourceScanExec.scala | 8 +- .../datasources/DataSourceStrategy.scala | 33 +-- .../datasources/v2/DataSourceV2Strategy.scala | 52 ++++- .../datasources/v2/PushDownUtils.scala | 10 +- .../v2/V2ScanRelationPushDown.scala | 4 + .../sql/connector/V1ReadFallbackSuite.scala | 194 ++++++++++++++++++ 8 files changed, 312 insertions(+), 36 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala index 57c83ec68a64..6824efd9880a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala @@ -34,8 +34,8 @@ class BasicInMemoryTableCatalog extends TableCatalog { protected val namespaces: util.Map[List[String], Map[String, String]] = new ConcurrentHashMap[List[String], Map[String, String]]() - protected val tables: util.Map[Identifier, InMemoryTable] = - new ConcurrentHashMap[Identifier, InMemoryTable]() + protected val tables: util.Map[Identifier, Table] = + new ConcurrentHashMap[Identifier, Table]() private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala b/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala new file mode 100644 index 000000000000..7cda82ada0d2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read + +import org.apache.spark.annotation.{Experimental, Unstable} +import org.apache.spark.sql.sources.BaseRelation + +/** + * A trait that should be implemented by V1 DataSources that would like to leverage the DataSource + * V2 read code paths. + * + * This interface is designed to provide Spark DataSources time to migrate to DataSource V2 and + * will be removed in a future Spark release. + * + * @since 3.0.0 + */ +@Experimental +@Unstable +trait V1Scan extends Scan { + + /** + * Creates an `BaseRelation` that can scan data from DataSource v1 to RDD[Row]. 
The returned + * relation must be a `TableScan` instance. + * + * @since 3.0.0 + */ + def toV1Relation(): BaseRelation +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 0d759085a7e2..37be10e13e4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -97,8 +97,7 @@ trait DataSourceScanExec extends LeafExecNode { /** Physical plan node for scanning data from a relation. */ case class RowDataSourceScanExec( - fullOutput: Seq[Attribute], - requiredColumnsIndex: Seq[Int], + output: Seq[Attribute], filters: Set[Filter], handledFilters: Set[Filter], rdd: RDD[InternalRow], @@ -106,8 +105,6 @@ case class RowDataSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with InputRDDCodegen { - def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput) - override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -141,7 +138,8 @@ case class RowDataSourceScanExec( // Don't care about `rdd` and `tableIdentifier` when canonicalizing. override def doCanonicalize(): SparkPlan = copy( - fullOutput.map(QueryPlan.normalizeExpressions(_, fullOutput)), + // Only the required column names matter when checking equality. + output.map(a => a.withExprId(ExprId(-1))), rdd = null, tableIdentifier = None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d44cb11e2876..aada8cb73ccf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -296,7 +296,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with case l @ LogicalRelation(baseRelation: TableScan, _, _, _) => RowDataSourceScanExec( l.output, - l.output.indices, Set.empty, Set.empty, toCatalystRDD(l, baseRelation.buildScan()), @@ -368,8 +367,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with .map(relation.attributeMap) val scan = RowDataSourceScanExec( - relation.output, - requestedColumns.map(relation.output.indexOf), + requestedColumns, pushedFilters.toSet, handledFilters, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), @@ -390,8 +388,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq val scan = RowDataSourceScanExec( - relation.output, - requestedColumns.map(relation.output.indexOf), + requestedColumns, pushedFilters.toSet, handledFilters, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), @@ -409,14 +406,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with relation: LogicalRelation, output: Seq[Attribute], rdd: RDD[Row]): RDD[InternalRow] = { - if (relation.relation.needConversion) { - val converters = RowEncoder(StructType.fromAttributes(output)) - rdd.mapPartitions { iterator => - iterator.map(converters.toRow) - } - } else { - rdd.asInstanceOf[RDD[InternalRow]] - } + DataSourceStrategy.toCatalystRDD(relation.relation, output, rdd) } /** @@ -624,4 +614,21 @@ object 
DataSourceStrategy { (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } + + /** + * Convert RDD of Row into RDD of InternalRow with objects in catalyst types + */ + private[sql] def toCatalystRDD( + relation: BaseRelation, + output: Seq[Attribute], + rdd: RDD[Row]): RDD[InternalRow] = { + if (relation.needConversion) { + val converters = RowEncoder(StructType.fromAttributes(output)) + rdd.mapPartitions { iterator => + iterator.map(converters.toRow) + } + } else { + rdd.asInstanceOf[RDD[InternalRow]] + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index b452b66e0381..58009a9ef675 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -19,34 +19,68 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ -import org.apache.spark.sql.{AnalysisException, Strategy} +import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedTable} import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, SupportsNamespaces, TableCapability, TableCatalog, TableChange} +import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources.TableScan import org.apache.spark.sql.util.CaseInsensitiveStringMap -object DataSourceV2Strategy extends Strategy with PredicateHelper { +class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper { import DataSourceV2Implicits._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + // projection and filters were already pushed down in the optimizer. + // this uses PhysicalOperation to get the projection and ensure that if the batch scan does + // not support columnar, a projection is added to convert the rows to UnsafeRow. case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) => - // projection and filters were already pushed down in the optimizer. - // this uses PhysicalOperation to get the projection and ensure that if the batch scan does - // not support columnar, a projection is added to convert the rows to UnsafeRow. 
- val batchExec = BatchScanExec(relation.output, relation.scan) + val output = relation.output + val pushedFilters = relation.getTagValue(V2ScanRelationPushDown.PUSHED_FILTERS_TAG) + .getOrElse(Array.empty) + + val (scanExec, needsUnsafeConversion) = relation.scan match { + case v1Scan: V1Scan => + val v1Relation = v1Scan.toV1Relation() + if (v1Relation.schema != v1Scan.readSchema()) { + throw new IllegalArgumentException( + "The fallback v1 relation reports inconsistent schema:\n" + + "Schema of v2 scan: " + v1Scan.readSchema() + "\n" + + "Schema of v1 relation: " + v1Relation.schema) + } + val rdd = v1Relation match { + case s: TableScan => s.buildScan() + case _ => + throw new IllegalArgumentException( + "`V1Scan.toV1Relation` must return a `TableScan` instance.") + } + val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd) + val dsScan = RowDataSourceScanExec( + output, + pushedFilters.toSet, + pushedFilters.toSet, + unsafeRowRDD, + v1Relation, + tableIdentifier = None) + (dsScan, false) + case _ => + val batchScan = BatchScanExec(output, relation.scan) + (batchScan, !batchScan.supportsColumnar) + } + val filterCondition = filters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, batchExec)).getOrElse(batchExec) + val withFilter = filterCondition.map(FilterExec(_, scanExec)).getOrElse(scanExec) - val withProjection = if (withFilter.output != project || !batchExec.supportsColumnar) { + val withProjection = if (withFilter.output != project || needsUnsafeConversion) { ProjectExec(project, withFilter) } else { withFilter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 09a8a7ebb6dd..33338b06565c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -34,7 +34,7 @@ object PushDownUtils extends PredicateHelper { */ def pushFilters( scanBuilder: ScanBuilder, - filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + filters: Seq[Expression]): (Seq[sources.Filter], Seq[Expression]) = { scanBuilder match { case r: SupportsPushDownFilters => // A map from translated data source leaf node filters to original catalyst filter @@ -62,11 +62,7 @@ object PushDownUtils extends PredicateHelper { val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter => DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) } - // The filters which are marked as pushed to this data source - val pushedFilters = r.pushedFilters().map { filter => - DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) - } - (pushedFilters, untranslatableExprs ++ postScanFilters) + (r.pushedFilters(), untranslatableExprs ++ postScanFilters) case _ => (Nil, filters) } @@ -75,7 +71,7 @@ object PushDownUtils extends PredicateHelper { /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. * - * @return the created `ScanConfig`(since column pruning is the last step of operator pushdown), + * @return the `Scan` instance (since column pruning is the last step of operator pushdown), * and new output attributes after column pruning. 
*/ def pruneColumns( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 92acd3ba8d90..c72f0edd7d77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -21,11 +21,14 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpressi import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.datasources.DataSourceStrategy object V2ScanRelationPushDown extends Rule[LogicalPlan] { import DataSourceV2Implicits._ + val PUSHED_FILTERS_TAG = TreeNodeTag[Array[org.apache.spark.sql.sources.Filter]]("pushed_filters") + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case ScanOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) @@ -55,6 +58,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { """.stripMargin) val scanRelation = DataSourceV2ScanRelation(relation.table, scan, output) + scanRelation.setTagValue(PUSHED_FILTERS_TAG, pushedFilters) val projectionOverSchema = ProjectionOverSchema(output.toStructType) val projectionFunc = (expr: Expression) => expr transformDown { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala new file mode 100644 index 000000000000..1da47c7a3892 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext} +import org.apache.spark.sql.connector.catalog.{Identifier, SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns, V1Scan} +import org.apache.spark.sql.execution.RowDataSourceScanExec +import org.apache.spark.sql.sources.{BaseRelation, Filter, GreaterThan, TableScan} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +abstract class V1ReadFallbackSuite extends QueryTest with SharedSparkSession { + protected def baseTableScan(): DataFrame + + test("full scan") { + val df = baseTableScan() + val v1Scan = df.queryExecution.executedPlan.collect { + case s: RowDataSourceScanExec => s + } + assert(v1Scan.length == 1) + checkAnswer(df, Seq(Row(1, 10), Row(2, 20), Row(3, 30))) + } + + test("column pruning") { + val df = baseTableScan().select("i") + val v1Scan = df.queryExecution.executedPlan.collect { + case s: RowDataSourceScanExec => s + } + assert(v1Scan.length == 1) + assert(v1Scan.head.output.map(_.name) == Seq("i")) + checkAnswer(df, Seq(Row(1), Row(2), Row(3))) + } + + test("filter push down") { + val df = baseTableScan().filter("i > 1 and j < 30") + val v1Scan = df.queryExecution.executedPlan.collect { + case s: RowDataSourceScanExec => s + } + assert(v1Scan.length == 1) + // `j < 30` can't be pushed. + assert(v1Scan.head.handledFilters.size == 1) + checkAnswer(df, Seq(Row(2, 20))) + } + + test("filter push down + column pruning") { + val df = baseTableScan().filter("i > 1").select("i") + val v1Scan = df.queryExecution.executedPlan.collect { + case s: RowDataSourceScanExec => s + } + assert(v1Scan.length == 1) + assert(v1Scan.head.output.map(_.name) == Seq("i")) + assert(v1Scan.head.handledFilters.size == 1) + checkAnswer(df, Seq(Row(2), Row(3))) + } +} + +class V1ReadFallbackWithDataFrameReaderSuite extends V1ReadFallbackSuite { + override protected def baseTableScan(): DataFrame = { + spark.read.format(classOf[V1ReadFallbackTableProvider].getName).load() + } +} + +class V1ReadFallbackWithCatalogSuite extends V1ReadFallbackSuite { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.sql.catalog.read_fallback", classOf[V1ReadFallbackCatalog].getName) + sql("CREATE TABLE read_fallback.tbl(i int, j int) USING foo") + } + + override def afterAll(): Unit = { + spark.conf.unset("spark.sql.catalog.read_fallback") + super.afterAll() + } + + override protected def baseTableScan(): DataFrame = { + spark.table("read_fallback.tbl") + } +} + +class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + // To simplify the test implementation, only support fixed schema. 
+ if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) { + throw new UnsupportedOperationException + } + val table = new TableWithV1ReadFallback + tables.put(ident, table) + table + } +} + +object V1ReadFallbackCatalog { + val schema = new StructType().add("i", "int").add("j", "int") +} + +class V1ReadFallbackTableProvider extends TableProvider { + override def getTable(options: CaseInsensitiveStringMap): Table = { + new TableWithV1ReadFallback + } +} + +class TableWithV1ReadFallback extends Table with SupportsRead { + override def name(): String = "v1-read-fallback" + + override def schema(): StructType = V1ReadFallbackCatalog.schema + + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.BATCH_READ).asJava + } + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new V1ReadFallbackScanBuilder + } + + private class V1ReadFallbackScanBuilder extends ScanBuilder + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { + + private var requiredSchema: StructType = schema() + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + private var filters: Array[Filter] = Array.empty + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported + } + override def pushedFilters(): Array[Filter] = filters + + override def build(): Scan = new V1ReadFallbackScan(requiredSchema, filters) + } + + private class V1ReadFallbackScan( + requiredSchema: StructType, + filters: Array[Filter]) extends V1Scan { + override def readSchema(): StructType = requiredSchema + override def toV1Relation(): BaseRelation = { + new BaseRelation with TableScan { + override def sqlContext: SQLContext = SparkSession.active.sqlContext + override def schema: StructType = requiredSchema + override def buildScan(): RDD[Row] = { + val lowerBound = if (filters.isEmpty) { + 0 + } else { + filters.collect { case GreaterThan("i", v: Int) => v }.max + } + val data = Seq(Row(1, 10), Row(2, 20), Row(3, 30)).filter(_.getInt(0) > lowerBound) + val result = if (requiredSchema.length == 2) { + data + } else if (requiredSchema.map(_.name) == Seq("i")) { + data.map(row => Row(row.getInt(0))) + } else if (requiredSchema.map(_.name) == Seq("j")) { + data.map(row => Row(row.getInt(1))) + } else { + throw new UnsupportedOperationException + } + + SparkSession.active.sparkContext.makeRDD(result) + } + } + } + } +} From b9976b7d0ab7538fe6456d5f70727827abdae451 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Oct 2019 20:56:33 +0800 Subject: [PATCH 2/9] address comments --- .../org/apache/spark/sql/connector/read/V1Scan.scala | 3 ++- .../spark/sql/execution/DataSourceScanExec.scala | 8 +++++--- .../org/apache/spark/sql/execution/SparkPlanner.scala | 5 ++--- .../apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../execution/datasources/DataSourceStrategy.scala | 7 +++++-- .../datasources/v2/DataSourceV2Strategy.scala | 5 ++++- .../execution/streaming/IncrementalExecution.scala | 2 +- .../spark/sql/internal/BaseSessionStateBuilder.scala | 2 +- .../spark/sql/connector/V1ReadFallbackSuite.scala | 11 +++++------ .../spark/sql/hive/HiveSessionStateBuilder.scala | 2 +- 10 files changed, 27 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala index 7cda82ada0d2..fb2a03ad6c46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.read import org.apache.spark.annotation.{Experimental, Unstable} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.sources.BaseRelation /** @@ -39,5 +40,5 @@ trait V1Scan extends Scan { * * @since 3.0.0 */ - def toV1Relation(): BaseRelation + def toV1Relation(context: SQLContext): BaseRelation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 37be10e13e4e..0d759085a7e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -97,7 +97,8 @@ trait DataSourceScanExec extends LeafExecNode { /** Physical plan node for scanning data from a relation. */ case class RowDataSourceScanExec( - output: Seq[Attribute], + fullOutput: Seq[Attribute], + requiredColumnsIndex: Seq[Int], filters: Set[Filter], handledFilters: Set[Filter], rdd: RDD[InternalRow], @@ -105,6 +106,8 @@ case class RowDataSourceScanExec( override val tableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with InputRDDCodegen { + def output: Seq[Attribute] = requiredColumnsIndex.map(fullOutput) + override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -138,8 +141,7 @@ case class RowDataSourceScanExec( // Don't care about `rdd` and `tableIdentifier` when canonicalizing. override def doCanonicalize(): SparkPlan = copy( - // Only the required column names matter when checking equality. 
- output.map(a => a.withExprId(ExprId(-1))), + fullOutput.map(QueryPlan.normalizeExpressions(_, fullOutput)), rdd = null, tableIdentifier = None) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index dc7fb7741e7a..895eeedd86b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -27,7 +26,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy import org.apache.spark.sql.internal.SQLConf class SparkPlanner( - val sparkContext: SparkContext, + val session: SparkSession, val conf: SQLConf, val experimentalMethods: ExperimentalMethods) extends SparkStrategies { @@ -39,7 +38,7 @@ class SparkPlanner( extraPlanningStrategies ++ ( LogicalQueryStageStrategy :: PythonEvals :: - DataSourceV2Strategy :: + new DataSourceV2Strategy(session) :: FileSourceStrategy :: DataSourceStrategy(conf) :: SpecialLimits :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 418401ac4e5c..00ad4e0fe0c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -570,7 +570,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) + protected lazy val singleRowRdd = session.sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index aada8cb73ccf..e3a0a0a6c34e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -296,6 +296,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with case l @ LogicalRelation(baseRelation: TableScan, _, _, _) => RowDataSourceScanExec( l.output, + l.output.indices, Set.empty, Set.empty, toCatalystRDD(l, baseRelation.buildScan()), @@ -367,7 +368,8 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with .map(relation.attributeMap) val scan = RowDataSourceScanExec( - requestedColumns, + relation.output, + requestedColumns.map(relation.output.indexOf), pushedFilters.toSet, handledFilters, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), @@ -388,7 +390,8 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq val scan = RowDataSourceScanExec( - requestedColumns, + relation.output, + requestedColumns.map(relation.output.indexOf), pushedFilters.toSet, handledFilters, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 58009a9ef675..e27518450625 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -49,7 +49,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val (scanExec, needsUnsafeConversion) = relation.scan match { case v1Scan: V1Scan => - val v1Relation = v1Scan.toV1Relation() + val v1Relation = v1Scan.toV1Relation(session.sqlContext) if (v1Relation.schema != v1Scan.readSchema()) { throw new IllegalArgumentException( "The fallback v1 relation reports inconsistent schema:\n" + @@ -63,8 +63,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat "`V1Scan.toV1Relation` must return a `TableScan` instance.") } val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd) + val originalOutputNames = relation.table.schema().map(_.name) + val requiredColumnsIndex = output.map(_.name).map(originalOutputNames.indexOf) val dsScan = RowDataSourceScanExec( output, + requiredColumnsIndex, pushedFilters.toSet, pushedFilters.toSet, unsafeRowRDD, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index b8e18b89b54b..bf80a0b1c167 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -50,7 +50,7 @@ class IncrementalExecution( // Modified planner with stateful operations. override val planner: SparkPlanner = new SparkPlanner( - sparkSession.sparkContext, + sparkSession, sparkSession.sessionState.conf, sparkSession.sessionState.experimentalMethods) { override def strategies: Seq[Strategy] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 91c693ab34c8..eb658e2d8850 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -250,7 +250,7 @@ abstract class BaseSessionStateBuilder( * Note: this depends on the `conf` and `experimentalMethods` fields. 
*/ protected def planner: SparkPlanner = { - new SparkPlanner(session.sparkContext, conf, experimentalMethods) { + new SparkPlanner(session, conf, experimentalMethods) { override def extraPlanningStrategies: Seq[Strategy] = super.extraPlanningStrategies ++ customPlanningStrategies } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index 1da47c7a3892..8364e6abf6d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -110,7 +110,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) { throw new UnsupportedOperationException } - val table = new TableWithV1ReadFallback + val table = new TableWithV1ReadFallback(ident.toString) tables.put(ident, table) table } @@ -122,12 +122,11 @@ object V1ReadFallbackCatalog { class V1ReadFallbackTableProvider extends TableProvider { override def getTable(options: CaseInsensitiveStringMap): Table = { - new TableWithV1ReadFallback + new TableWithV1ReadFallback("v1-read-fallback") } } -class TableWithV1ReadFallback extends Table with SupportsRead { - override def name(): String = "v1-read-fallback" +class TableWithV1ReadFallback(override val name: String) extends Table with SupportsRead { override def schema(): StructType = V1ReadFallbackCatalog.schema @@ -165,9 +164,9 @@ class TableWithV1ReadFallback extends Table with SupportsRead { requiredSchema: StructType, filters: Array[Filter]) extends V1Scan { override def readSchema(): StructType = requiredSchema - override def toV1Relation(): BaseRelation = { + override def toV1Relation(context: SQLContext): BaseRelation = { new BaseRelation with TableScan { - override def sqlContext: SQLContext = SparkSession.active.sqlContext + override def sqlContext: SQLContext = context override def schema: StructType = requiredSchema override def buildScan(): RDD[Row] = { val lowerBound = if (filters.isEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 3df77fec2099..de21a13e6edb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -97,7 +97,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Planner that takes into account Hive-specific strategies. 
*/ override protected def planner: SparkPlanner = { - new SparkPlanner(session.sparkContext, conf, experimentalMethods) with HiveStrategies { + new SparkPlanner(session, conf, experimentalMethods) with HiveStrategies { override val sparkSession: SparkSession = session override def extraPlanningStrategies: Seq[Strategy] = From 9f34319fe8c5eeddfc4e28172eb39f3771544704 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 1 Nov 2019 21:34:36 +0800 Subject: [PATCH 3/9] fix --- .../sql/execution/datasources/v2/DataSourceV2Strategy.scala | 2 +- .../sql/execution/datasources/v2/V2ScanRelationPushDown.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index e27518450625..30ecc344fc5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -45,7 +45,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) => val output = relation.output val pushedFilters = relation.getTagValue(V2ScanRelationPushDown.PUSHED_FILTERS_TAG) - .getOrElse(Array.empty) + .getOrElse(Seq.empty) val (scanExec, needsUnsafeConversion) = relation.scan match { case v1Scan: V1Scan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index c72f0edd7d77..d2f2baef254d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy object V2ScanRelationPushDown extends Rule[LogicalPlan] { import DataSourceV2Implicits._ - val PUSHED_FILTERS_TAG = TreeNodeTag[Array[org.apache.spark.sql.sources.Filter]]("pushed_filters") + val PUSHED_FILTERS_TAG = TreeNodeTag[Seq[org.apache.spark.sql.sources.Filter]]("pushed_filters") override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case ScanOperation(project, filters, relation: DataSourceV2Relation) => From 1e25565685aa810c1f4a380e24eacf56352e50e6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Dec 2019 17:10:07 +0800 Subject: [PATCH 4/9] address comment --- .../main/scala/org/apache/spark/sql/connector/read/V1Scan.scala | 2 +- .../sql/execution/datasources/v2/DataSourceV2Strategy.scala | 2 +- .../org/apache/spark/sql/connector/V1ReadFallbackSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala b/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala index fb2a03ad6c46..c87a244b8138 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala @@ -40,5 +40,5 @@ trait V1Scan extends Scan { * * @since 3.0.0 */ - def toV1Relation(context: SQLContext): BaseRelation + def toV1TableScan(context: SQLContext): BaseRelation } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 30ecc344fc5c..275ae0801048 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -49,7 +49,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val (scanExec, needsUnsafeConversion) = relation.scan match { case v1Scan: V1Scan => - val v1Relation = v1Scan.toV1Relation(session.sqlContext) + val v1Relation = v1Scan.toV1TableScan(session.sqlContext) if (v1Relation.schema != v1Scan.readSchema()) { throw new IllegalArgumentException( "The fallback v1 relation reports inconsistent schema:\n" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index 8364e6abf6d6..da9da4f107ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -164,7 +164,7 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp requiredSchema: StructType, filters: Array[Filter]) extends V1Scan { override def readSchema(): StructType = requiredSchema - override def toV1Relation(context: SQLContext): BaseRelation = { + override def toV1TableScan(context: SQLContext): BaseRelation = { new BaseRelation with TableScan { override def sqlContext: SQLContext = context override def schema: StructType = requiredSchema From d9ae863ea41bbce282e8451288e4cf5df3c91379 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Dec 2019 14:23:33 +0800 Subject: [PATCH 5/9] address comment --- .../datasources/v2/DataSourceV2Strategy.scala | 79 ++++++++++--------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 275ae0801048..5a02e214127f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -39,51 +39,56 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // projection and filters were already pushed down in the optimizer. - // this uses PhysicalOperation to get the projection and ensure that if the batch scan does - // not support columnar, a projection is added to convert the rows to UnsafeRow. 
- case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) => - val output = relation.output + case PhysicalOperation(project, filters, + relation @ DataSourceV2ScanRelation(table, v1Scan: V1Scan, output)) => val pushedFilters = relation.getTagValue(V2ScanRelationPushDown.PUSHED_FILTERS_TAG) .getOrElse(Seq.empty) - - val (scanExec, needsUnsafeConversion) = relation.scan match { - case v1Scan: V1Scan => - val v1Relation = v1Scan.toV1TableScan(session.sqlContext) - if (v1Relation.schema != v1Scan.readSchema()) { - throw new IllegalArgumentException( - "The fallback v1 relation reports inconsistent schema:\n" + - "Schema of v2 scan: " + v1Scan.readSchema() + "\n" + - "Schema of v1 relation: " + v1Relation.schema) - } - val rdd = v1Relation match { - case s: TableScan => s.buildScan() - case _ => - throw new IllegalArgumentException( - "`V1Scan.toV1Relation` must return a `TableScan` instance.") - } - val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd) - val originalOutputNames = relation.table.schema().map(_.name) - val requiredColumnsIndex = output.map(_.name).map(originalOutputNames.indexOf) - val dsScan = RowDataSourceScanExec( - output, - requiredColumnsIndex, - pushedFilters.toSet, - pushedFilters.toSet, - unsafeRowRDD, - v1Relation, - tableIdentifier = None) - (dsScan, false) + val v1Relation = v1Scan.toV1TableScan(session.sqlContext) + if (v1Relation.schema != v1Scan.readSchema()) { + throw new IllegalArgumentException( + "The fallback v1 relation reports inconsistent schema:\n" + + "Schema of v2 scan: " + v1Scan.readSchema() + "\n" + + "Schema of v1 relation: " + v1Relation.schema) + } + val rdd = v1Relation match { + case s: TableScan => s.buildScan() case _ => - val batchScan = BatchScanExec(output, relation.scan) - (batchScan, !batchScan.supportsColumnar) + throw new IllegalArgumentException( + "`V1Scan.toV1Relation` must return a `TableScan` instance.") } + val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd) + val originalOutputNames = relation.table.schema().map(_.name) + val requiredColumnsIndex = output.map(_.name).map(originalOutputNames.indexOf) + val dsScan = RowDataSourceScanExec( + output, + requiredColumnsIndex, + pushedFilters.toSet, + pushedFilters.toSet, + unsafeRowRDD, + v1Relation, + tableIdentifier = None) + + val filterCondition = filters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, dsScan)).getOrElse(dsScan) + val withProjection = if (withFilter.output != project) { + ProjectExec(project, withFilter) + } else { + withFilter + } + + withProjection :: Nil + + case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) => + // projection and filters were already pushed down in the optimizer. + // this uses PhysicalOperation to get the projection and ensure that if the batch scan does + // not support columnar, a projection is added to convert the rows to UnsafeRow. 
+ val batchExec = BatchScanExec(relation.output, relation.scan) val filterCondition = filters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, scanExec)).getOrElse(scanExec) + val withFilter = filterCondition.map(FilterExec(_, batchExec)).getOrElse(batchExec) - val withProjection = if (withFilter.output != project || needsUnsafeConversion) { + val withProjection = if (withFilter.output != project || !batchExec.supportsColumnar) { ProjectExec(project, withFilter) } else { withFilter From 938b60764ba4481e0e0dc184b8357cd0c62a1542 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Dec 2019 17:10:11 +0800 Subject: [PATCH 6/9] address comment --- .../datasources/v2/DataSourceV2Strategy.scala | 45 +++++++++---------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 5a02e214127f..df693db41ce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -21,13 +21,13 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedTable} -import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, SupportsNamespaces, TableCapability, TableCatalog, TableChange} import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.TableScan @@ -38,9 +38,24 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat import DataSourceV2Implicits._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + private def withProjectAndFilter( + project: Seq[NamedExpression], + filters: Seq[Expression], + scan: LeafExecNode, + needsUnsafeConversion: Boolean): SparkPlan = { + val filterCondition = filters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) + + if (withFilter.output != project || needsUnsafeConversion) { + ProjectExec(project, withFilter) + } else { + withFilter + } + } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, - relation @ DataSourceV2ScanRelation(table, v1Scan: V1Scan, output)) => + relation @ DataSourceV2ScanRelation(_, v1Scan: V1Scan, output)) => val pushedFilters = relation.getTagValue(V2ScanRelationPushDown.PUSHED_FILTERS_TAG) .getOrElse(Seq.empty) val v1Relation = 
v1Scan.toV1TableScan(session.sqlContext) @@ -67,34 +82,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat unsafeRowRDD, v1Relation, tableIdentifier = None) - - val filterCondition = filters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, dsScan)).getOrElse(dsScan) - - val withProjection = if (withFilter.output != project) { - ProjectExec(project, withFilter) - } else { - withFilter - } - - withProjection :: Nil + withProjectAndFilter(project, filters, dsScan, needsUnsafeConversion = false) :: Nil case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) => // projection and filters were already pushed down in the optimizer. // this uses PhysicalOperation to get the projection and ensure that if the batch scan does // not support columnar, a projection is added to convert the rows to UnsafeRow. val batchExec = BatchScanExec(relation.output, relation.scan) - - val filterCondition = filters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, batchExec)).getOrElse(batchExec) - - val withProjection = if (withFilter.output != project || !batchExec.supportsColumnar) { - ProjectExec(project, withFilter) - } else { - withFilter - } - - withProjection :: Nil + withProjectAndFilter(project, filters, batchExec, !batchExec.supportsColumnar) :: Nil case r: StreamingDataSourceV2Relation if r.startOffset.isDefined && r.endOffset.isDefined => val microBatchStream = r.stream.asInstanceOf[MicroBatchStream] From 16ad7ebba6335a3d1fd2b85cd42f0fd2bf52fa69 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Dec 2019 16:26:57 +0800 Subject: [PATCH 7/9] address comments --- .../spark/sql/connector/read/V1Scan.java} | 17 +++--- .../sql/connector/write/V1WriteBuilder.java} | 19 ++----- .../datasources/v2/DataSourceV2Strategy.scala | 11 ++-- .../sql/connector/V1ReadFallbackSuite.scala | 54 ++++++++++--------- 4 files changed, 46 insertions(+), 55 deletions(-) rename sql/core/src/main/{scala/org/apache/spark/sql/connector/read/V1Scan.scala => java/org/apache/spark/sql/connector/read/V1Scan.java} (72%) rename sql/core/src/main/{scala/org/apache/spark/sql/connector/write/V1WriteBuilder.scala => java/org/apache/spark/sql/connector/write/V1WriteBuilder.java} (73%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala b/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java similarity index 72% rename from sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala rename to sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java index c87a244b8138..816cea1b6f26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/connector/read/V1Scan.scala +++ b/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java @@ -15,11 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.connector.read +package org.apache.spark.sql.connector.read; -import org.apache.spark.annotation.{Experimental, Unstable} -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.annotation.Unstable; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.sources.BaseRelation; +import org.apache.spark.sql.sources.TableScan; /** * A trait that should be implemented by V1 DataSources that would like to leverage the DataSource @@ -30,15 +31,13 @@ * * @since 3.0.0 */ -@Experimental @Unstable -trait V1Scan extends Scan { +public interface V1Scan extends Scan { /** - * Creates an `BaseRelation` that can scan data from DataSource v1 to RDD[Row]. The returned - * relation must be a `TableScan` instance. + * Creates an `BaseRelation` with `TableScan` that can scan data from DataSource v1 to RDD[Row]. * * @since 3.0.0 */ - def toV1TableScan(context: SQLContext): BaseRelation + T toV1TableScan(SQLContext context); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/connector/write/V1WriteBuilder.scala b/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java similarity index 73% rename from sql/core/src/main/scala/org/apache/spark/sql/connector/write/V1WriteBuilder.scala rename to sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java index e738ad1ede44..89b567b5231a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/connector/write/V1WriteBuilder.scala +++ b/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.connector.write +package org.apache.spark.sql.connector.write; -import org.apache.spark.annotation.{Experimental, Unstable} -import org.apache.spark.sql.connector.write.streaming.StreamingWrite -import org.apache.spark.sql.sources.InsertableRelation +import org.apache.spark.annotation.Unstable; +import org.apache.spark.sql.sources.InsertableRelation; /** * A trait that should be implemented by V1 DataSources that would like to leverage the DataSource @@ -32,10 +31,8 @@ * * @since 3.0.0 */ -@Experimental @Unstable -trait V1WriteBuilder extends WriteBuilder { - +public interface V1WriteBuilder extends WriteBuilder { /** * Creates an InsertableRelation that allows appending a DataFrame to a * a destination (using data source-specific parameters). The insert method will only be @@ -44,11 +41,5 @@ * * @since 3.0.0 */ - def buildForV1Write(): InsertableRelation - - // These methods cannot be implemented by a V1WriteBuilder. 
The super class will throw - // an Unsupported OperationException - override final def buildForBatch(): BatchWrite = super.buildForBatch() - - override final def buildForStreaming(): StreamingWrite = super.buildForStreaming() + InsertableRelation buildForV1Write(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index df693db41ce1..f7779f1eefb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBat import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} -import org.apache.spark.sql.sources.TableScan +import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.util.CaseInsensitiveStringMap class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper { @@ -58,19 +58,14 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat relation @ DataSourceV2ScanRelation(_, v1Scan: V1Scan, output)) => val pushedFilters = relation.getTagValue(V2ScanRelationPushDown.PUSHED_FILTERS_TAG) .getOrElse(Seq.empty) - val v1Relation = v1Scan.toV1TableScan(session.sqlContext) + val v1Relation = v1Scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != v1Scan.readSchema()) { throw new IllegalArgumentException( "The fallback v1 relation reports inconsistent schema:\n" + "Schema of v2 scan: " + v1Scan.readSchema() + "\n" + "Schema of v1 relation: " + v1Relation.schema) } - val rdd = v1Relation match { - case s: TableScan => s.buildScan() - case _ => - throw new IllegalArgumentException( - "`V1Scan.toV1Relation` must return a `TableScan` instance.") - } + val rdd = v1Relation.buildScan() val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd) val originalOutputNames = relation.table.schema().map(_.name) val requiredColumnsIndex = output.map(_.name).map(originalOutputNames.indexOf) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index da9da4f107ce..8e2c63417b37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -164,30 +164,36 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp requiredSchema: StructType, filters: Array[Filter]) extends V1Scan { override def readSchema(): StructType = requiredSchema - override def toV1TableScan(context: SQLContext): BaseRelation = { - new BaseRelation with TableScan { - override def sqlContext: SQLContext = context - override def schema: StructType = requiredSchema - override def buildScan(): RDD[Row] = { - val lowerBound = if (filters.isEmpty) { - 0 - } else { - filters.collect { case GreaterThan("i", v: Int) => v }.max - } - val data = Seq(Row(1, 10), Row(2, 20), Row(3, 
30)).filter(_.getInt(0) > lowerBound) - val result = if (requiredSchema.length == 2) { - data - } else if (requiredSchema.map(_.name) == Seq("i")) { - data.map(row => Row(row.getInt(0))) - } else if (requiredSchema.map(_.name) == Seq("j")) { - data.map(row => Row(row.getInt(1))) - } else { - throw new UnsupportedOperationException - } - - SparkSession.active.sparkContext.makeRDD(result) - } - } + + override def toV1TableScan[T <: BaseRelation with TableScan](context: SQLContext): T = { + new V1TableScan(context, requiredSchema, filters).asInstanceOf[T] } } } + +class V1TableScan( + context: SQLContext, + requiredSchema: StructType, + filters: Array[Filter]) extends BaseRelation with TableScan { + override def sqlContext: SQLContext = context + override def schema: StructType = requiredSchema + override def buildScan(): RDD[Row] = { + val lowerBound = if (filters.isEmpty) { + 0 + } else { + filters.collect { case GreaterThan("i", v: Int) => v }.max + } + val data = Seq(Row(1, 10), Row(2, 20), Row(3, 30)).filter(_.getInt(0) > lowerBound) + val result = if (requiredSchema.length == 2) { + data + } else if (requiredSchema.map(_.name) == Seq("i")) { + data.map(row => Row(row.getInt(0))) + } else if (requiredSchema.map(_.name) == Seq("j")) { + data.map(row => Row(row.getInt(1))) + } else { + throw new UnsupportedOperationException + } + + SparkSession.active.sparkContext.makeRDD(result) + } +} From 7e69bfb03fe6aefc4f9762246fff02133e6eb3e9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 13 Jan 2020 12:38:12 +0800 Subject: [PATCH 8/9] address comments --- .../main/java/org/apache/spark/sql/connector/read/V1Scan.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java b/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java index 816cea1b6f26..c9d7cb1bf80a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java +++ b/sql/core/src/main/java/org/apache/spark/sql/connector/read/V1Scan.java @@ -35,7 +35,7 @@ public interface V1Scan extends Scan { /** - * Creates an `BaseRelation` with `TableScan` that can scan data from DataSource v1 to RDD[Row]. + * Create an `BaseRelation` with `TableScan` that can scan data from DataSource v1 to RDD[Row]. 
* * @since 3.0.0 */ From a48e7bb31bdab8abcd2b6ba9aa37faa464351f61 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 16 Jan 2020 16:06:55 +0800 Subject: [PATCH 9/9] use a wrapper --- .../datasources/v2/DataSourceV2Strategy.scala | 15 +++++------ .../v2/V2ScanRelationPushDown.scala | 25 +++++++++++++++---- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index f7779f1eefb0..568ffba4854c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpressi import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.{StagingTableCatalog, SupportsNamespaces, TableCapability, TableCatalog, TableChange} -import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -55,14 +54,12 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, - relation @ DataSourceV2ScanRelation(_, v1Scan: V1Scan, output)) => - val pushedFilters = relation.getTagValue(V2ScanRelationPushDown.PUSHED_FILTERS_TAG) - .getOrElse(Seq.empty) - val v1Relation = v1Scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) - if (v1Relation.schema != v1Scan.readSchema()) { + relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) => + val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) + if (v1Relation.schema != scan.readSchema()) { throw new IllegalArgumentException( "The fallback v1 relation reports inconsistent schema:\n" + - "Schema of v2 scan: " + v1Scan.readSchema() + "\n" + + "Schema of v2 scan: " + scan.readSchema() + "\n" + "Schema of v1 relation: " + v1Relation.schema) } val rdd = v1Relation.buildScan() @@ -72,8 +69,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val dsScan = RowDataSourceScanExec( output, requiredColumnsIndex, - pushedFilters.toSet, - pushedFilters.toSet, + translated.toSet, + pushed.toSet, unsafeRowRDD, v1Relation, tableIdentifier = None) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index d2f2baef254d..59089fa6b77e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -21,14 +21,14 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpressi import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import 
org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.connector.read.{Scan, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources +import org.apache.spark.sql.types.StructType object V2ScanRelationPushDown extends Rule[LogicalPlan] { import DataSourceV2Implicits._ - val PUSHED_FILTERS_TAG = TreeNodeTag[Seq[org.apache.spark.sql.sources.Filter]]("pushed_filters") - override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case ScanOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.table.asReadable.newScanBuilder(relation.options) @@ -57,8 +57,14 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { |Output: ${output.mkString(", ")} """.stripMargin) - val scanRelation = DataSourceV2ScanRelation(relation.table, scan, output) - scanRelation.setTagValue(PUSHED_FILTERS_TAG, pushedFilters) + val wrappedScan = scan match { + case v1: V1Scan => + val translated = filters.flatMap(DataSourceStrategy.translateFilter) + V1ScanWrapper(v1, translated, pushedFilters) + case _ => scan + } + + val scanRelation = DataSourceV2ScanRelation(relation.table, wrappedScan, output) val projectionOverSchema = ProjectionOverSchema(output.toStructType) val projectionFunc = (expr: Expression) => expr transformDown { @@ -81,3 +87,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { withProjection } } + +// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by +// the physical v1 scan node. +case class V1ScanWrapper( + v1Scan: V1Scan, + translatedFilters: Seq[sources.Filter], + handledFilters: Seq[sources.Filter]) extends Scan { + override def readSchema(): StructType = v1Scan.readSchema() +}