diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index a6be98c8a3aa..69d30dd5048d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -24,15 +24,19 @@ import org.apache.spark.sql.types._ * field indexes and field counts of complex type extractors and attributes * are adjusted to fit the schema. All other expressions are left as-is. This * class is motivated by columnar nested schema pruning. + * + * @param schema nested column schema + * @param output output attributes of the data source relation. They are used to filter out + * attributes in the schema that do not belong to the current relation. */ -case class ProjectionOverSchema(schema: StructType) { +case class ProjectionOverSchema(schema: StructType, output: AttributeSet) { private val fieldNames = schema.fieldNames.toSet def unapply(expr: Expression): Option[Expression] = getProjection(expr) private def getProjection(expr: Expression): Option[Expression] = expr match { - case a: AttributeReference if fieldNames.contains(a.name) => + case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) => Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier)) case GetArrayItem(child, arrayItemOrdinal, failOnError) => getProjection(child).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bb788336c6d7..2dc775e8e667 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -60,6 +60,7 @@ abstract class Optimizer(catalogManager: CatalogManager) override protected val excludedOnceBatches: Set[String] = Set( "PartitionPruning", + "RewriteSubquery", "Extract Python UDFs") protected def fixedPoint = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 82aef32c5a22..3387bb200770 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -229,7 +229,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { } // Builds new projection. - val projectionOverSchema = ProjectionOverSchema(prunedSchema) + val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output)) val newProjects = p.projectList.map(_.transformDown { case projectionOverSchema(expr) => expr }).map { case expr: NamedExpression => expr } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index a49c10c852b0..26d5d92fecb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -91,8 +91,8 @@ object SchemaPruning extends Rule[LogicalPlan] { if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) || countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) { val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema) - val projectionOverSchema = - ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema)) + val projectionOverSchema = ProjectionOverSchema( + prunedDataSchema.merge(prunedMetadataSchema), AttributeSet(relation.output)) Some(buildNewProjection(projects, normalizedProjects, normalizedFilters, prunedRelation, projectionOverSchema)) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 6455e2508927..b7e0531989f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} -import org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.ScanOperation @@ -320,7 +319,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - val projectionOverSchema = ProjectionOverSchema(output.toStructType) + val projectionOverSchema = + ProjectionOverSchema(output.toStructType, AttributeSet(output)) val projectionFunc = (expr: Expression) => expr transformDown { case projectionOverSchema(newExpr) => newExpr } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 4eb8258830ce..3c715ca602a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -61,11 +61,15 @@ abstract class SchemaPruningSuite override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.ANSI_STRICT_INDEX_OPERATOR.key, "false") + case class Employee(id: Int, name: FullName, employer: Company) + val janeDoe = FullName("Jane", "X.", "Doe") val johnDoe = FullName("John", "Y.", "Doe") val susanSmith = FullName("Susan", "Z.", "Smith") - val employer = Employer(0, Company("abc", "123 Business Street")) + val company = Company("abc", "123 Business Street") + + val employer = Employer(0, company) val employerWithNullCompany = Employer(1, null) val employerWithNullCompany2 = Employer(2, null) @@ -81,6 +85,8 @@ abstract class SchemaPruningSuite Department(1, "Marketing", 1, employerWithNullCompany) :: Department(2, "Operation", 4, employerWithNullCompany2) :: Nil + val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil + case class Name(first: String, last: String) case class BriefContact(id: Int, name: Name, address: String) @@ -621,6 +627,26 @@ abstract class SchemaPruningSuite } } + testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") { + withContacts { + withEmployees { + val query = sql( + """ + |select count(*) + |from contacts c + |where not exists (select null from employees e where e.name.first = c.name.first + | and e.employer.name = c.employer.company.name) + |""".stripMargin) + checkScan(query, + "struct," + + "employer:struct>>", + "struct," + + "employer:struct>") + checkAnswer(query, Row(3)) + } + } + } + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") { @@ -701,6 +727,23 @@ abstract class SchemaPruningSuite } } + private def withEmployees(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeDataSourceFile(employees, new File(path + "/employees")) + + // Providing user specified schema. Inferred schema from different data sources might + // be different. + val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " + + "`employer` STRUCT<`name`: STRING, `address`: STRING>" + spark.read.format(dataSourceName).schema(schema).load(path + "/employees") + .createOrReplaceTempView("employees") + + testThunk + } + } + case class MixedCaseColumn(a: String, B: Int) case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)