diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 16b22717b8d9..e55a0f5ab761 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.SparkPlan
 
@@ -51,6 +51,15 @@ import org.apache.spark.sql.execution.SparkPlan
  */
 object FileSourceStrategy extends Strategy with Logging {
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case p @ Project(fields, child)
+        if !fields.forall(_.deterministic) && p.references.nonEmpty =>
+      collectFileSource(Project(child.output.filter(p.references.contains), child))
+        .map(p => execution.ProjectExec(fields, p)).toList
+
+    case _ => collectFileSource(plan)
+  }
+
+  private def collectFileSource(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case PhysicalOperation(projects, filters,
       l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
       // Filters on this relation fall into four categories based on where we can use them to avoid
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 13341645e8ff..30a6e1fe3900 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2034,6 +2034,28 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-21520: the fields of project contains nondeterministic") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      withTempPath { path =>
+        val p = path.getAbsolutePath
+        Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c").write.partitionBy("a").parquet(p)
+        val df = spark.read.parquet(p)
+
+        val qe = df.select($"a", rand(10).as('rand))
+        // FileScan parquet [a#38]
+        assert(qe.queryExecution.sparkPlan.inputSet.toString.contains("a#"))
+        assert(!qe.queryExecution.sparkPlan.inputSet.toString.contains("b#"))
+        assert(!qe.queryExecution.sparkPlan.inputSet.toString.contains("c#"))
+
+        val qe2 = df.select($"a", $"b", rand(10).as('rand2))
+        // FileScan parquet [b#70,a#72]
+        assert(qe2.queryExecution.sparkPlan.inputSet.toString.contains("a#"))
+        assert(qe2.queryExecution.sparkPlan.inputSet.toString.contains("b#"))
+        assert(!qe2.queryExecution.sparkPlan.inputSet.toString.contains("c#"))
+      }
+    }
+  }
+
   test("order-by ordinal.") {
     checkAnswer(
       testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index caf554d9ea51..d9abf1f3a236 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -26,8 +26,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan,
-  ScriptTransformation}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
@@ -239,6 +238,15 @@ private[hive] trait HiveStrategies {
    */
   object HiveTableScans extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case p @ Project(fields, child)
+          if !fields.forall(_.deterministic) && p.references.nonEmpty =>
+        collectHiveTableSource(Project(child.output.filter(p.references.contains), child))
+          .map(p => ProjectExec(fields, p)).toList
+
+      case _ => collectHiveTableSource(plan)
+    }
+
+    private def collectHiveTableSource(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) =>
         // Filter out all predicates that only deal with partition keys, these are given to the
         // hive table scan operator to be used for partition pruning.
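
Note (not part of the patch): below is a minimal, self-contained sketch of the behavior this change targets, assuming a local SparkSession and a throwaway parquet directory; the object name, path, and session setup are illustrative only. With the new Project handling in FileSourceStrategy, selecting a plain column together with a nondeterministic expression such as rand() should still let the scan prune unreferenced columns, which the test above checks via sparkPlan.inputSet.

// Illustrative sketch only; names and setup here are assumptions, not part of the patch.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

object NondeterministicProjectPruningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-21520 demo")
      .getOrCreate()
    import spark.implicits._

    // Write a small partitioned parquet table with columns a, b, c.
    val path = java.nio.file.Files.createTempDirectory("spark21520").resolve("t").toString
    Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c").write.partitionBy("a").parquet(path)

    val df = spark.read.parquet(path)
    // Mix a plain column with rand(); with the patched strategy, columns b and c
    // should not appear in the physical scan's input set.
    val q = df.select($"a", rand(10).as("r"))
    q.explain()
    println(q.queryExecution.sparkPlan.inputSet)

    spark.stop()
  }
}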