diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
index 1daf8ae72b63..d053ea98f8b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FilePartitionReaderFactory.scala
@@ -46,19 +46,6 @@ abstract class FilePartitionReaderFactory extends PartitionReaderFactory {
   def buildColumnarReader(partitionedFile: PartitionedFile): PartitionReader[ColumnarBatch] = {
     throw new UnsupportedOperationException("Cannot create columnar reader.")
   }
-
-  protected def getReadDataSchema(
-      readSchema: StructType,
-      partitionSchema: StructType,
-      isCaseSensitive: Boolean): StructType = {
-    val partitionNameSet =
-      partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
-    val fields = readSchema.fields.filterNot { field =>
-      partitionNameSet.contains(PartitioningUtils.getColName(field, isCaseSensitive))
-    }
-
-    StructType(fields)
-  }
 }
 
 // A compound class for combining file and its corresponding reader.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index e971fd762efe..337aac9ea651 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -16,9 +16,13 @@
  */
 package org.apache.spark.sql.execution.datasources.v2
 
+import java.util.Locale
+
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.PartitionedFileUtil
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
@@ -28,8 +32,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 abstract class FileScan(
     sparkSession: SparkSession,
     fileIndex: PartitioningAwareFileIndex,
-    readSchema: StructType,
-    options: CaseInsensitiveStringMap) extends Scan with Batch {
+    readDataSchema: StructType,
+    readPartitionSchema: StructType) extends Scan with Batch {
   /**
    * Returns whether a file with `path` could be split or not.
   */
@@ -40,7 +44,23 @@ abstract class FileScan(
   protected def partitions: Seq[FilePartition] = {
     val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
     val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
+    val partitionAttributes = fileIndex.partitionSchema.toAttributes
+    val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap
+    val readPartitionAttributes = readPartitionSchema.map { readField =>
+      attributeMap.get(normalizeName(readField.name)).getOrElse {
+        throw new AnalysisException(s"Can't find required partition column ${readField.name} " +
+          s"in partition schema ${fileIndex.partitionSchema}")
+      }
+    }
+    lazy val partitionValueProject =
+      GenerateUnsafeProjection.generate(readPartitionAttributes, partitionAttributes)
     val splitFiles = selectedPartitions.flatMap { partition =>
+      // Prune partition values if part of the partition columns are not required.
+      val partitionValues = if (readPartitionAttributes != partitionAttributes) {
+        partitionValueProject(partition.values).copy()
+      } else {
+        partition.values
+      }
       partition.files.flatMap { file =>
         val filePath = file.getPath
         PartitionedFileUtil.splitFiles(
@@ -49,7 +69,7 @@ abstract class FileScan(
           filePath = filePath,
           isSplitable = isSplitable(filePath),
           maxSplitBytes = maxSplitBytes,
-          partitionValues = partition.values
+          partitionValues = partitionValues
         )
       }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
     }
@@ -61,4 +81,17 @@ abstract class FileScan(
   }
 
   override def toBatch: Batch = this
+
+  override def readSchema(): StructType =
+    StructType(readDataSchema.fields ++ readPartitionSchema.fields)
+
+  private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+
+  private def normalizeName(name: String): String = {
+    if (isCaseSensitive) {
+      name
+    } else {
+      name.toLowerCase(Locale.ROOT)
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index d4e55a50307d..3b236be90e6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -16,15 +16,44 @@
  */
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.sources.v2.reader.{ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils}
+import org.apache.spark.sql.sources.v2.reader.{ScanBuilder, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.types.StructType
 
-abstract class FileScanBuilder(schema: StructType)
-  extends ScanBuilder
-  with SupportsPushDownRequiredColumns {
-  protected var readSchema = schema
+abstract class FileScanBuilder(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns {
+  private val partitionSchema = fileIndex.partitionSchema
+  private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+  protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields)
 
   override def pruneColumns(requiredSchema: StructType): Unit = {
-    this.readSchema = requiredSchema
+    this.requiredSchema = requiredSchema
   }
+
+  protected def readDataSchema(): StructType = {
+    val requiredNameSet = createRequiredNameSet()
+    val fields = dataSchema.fields.filter { field =>
+      val colName = PartitioningUtils.getColName(field, isCaseSensitive)
+      requiredNameSet.contains(colName) && !partitionNameSet.contains(colName)
+    }
+    StructType(fields)
+  }
+
+  protected def readPartitionSchema(): StructType = {
+    val requiredNameSet = createRequiredNameSet()
+    val fields = partitionSchema.fields.filter { field =>
+      val colName = PartitioningUtils.getColName(field, isCaseSensitive)
+      requiredNameSet.contains(colName)
+    }
+    StructType(fields)
+  }
+
+  private def createRequiredNameSet(): Set[String] =
+    requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
+
+  private val partitionNameSet: Set[String] =
+    partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala
index 8d9cc68417ef..d6b84dcdfd15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TextBasedFileScan.scala
@@ -29,9 +29,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 abstract class TextBasedFileScan(
     sparkSession: SparkSession,
     fileIndex: PartitioningAwareFileIndex,
-    readSchema: StructType,
+    readDataSchema: StructType,
+    readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends FileScan(sparkSession, fileIndex, readSchema, options) {
+  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
   private var codecFactory: CompressionCodecFactory = _
 
   override def isSplitable(path: Path): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
index e2d50282e9cb..28e310489cd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala
@@ -33,24 +33,21 @@ import org.apache.spark.util.SerializableConfiguration
 * @param sqlConf SQL configuration.
 * @param broadcastedConf Broadcasted serializable Hadoop Configuration.
 * @param dataSchema Schema of CSV files.
+ * @param readDataSchema Required data schema in the batch scan.
 * @param partitionSchema Schema of partitions.
- * @param readSchema Required schema in the batch scan.
 * @param parsedOptions Options for parsing CSV files.
 */
 case class CSVPartitionReaderFactory(
     sqlConf: SQLConf,
     broadcastedConf: Broadcast[SerializableConfiguration],
     dataSchema: StructType,
+    readDataSchema: StructType,
     partitionSchema: StructType,
-    readSchema: StructType,
     parsedOptions: CSVOptions) extends FilePartitionReaderFactory {
   private val columnPruning = sqlConf.csvColumnPruning
-  private val readDataSchema =
-    getReadDataSchema(readSchema, partitionSchema, sqlConf.caseSensitiveAnalysis)
 
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value
-
     val parser = new UnivocityParser(
       StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
       StructType(readDataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
index 8f2f8f256731..5bc8029b4068 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -35,9 +35,10 @@ case class CSVScan(
     sparkSession: SparkSession,
     fileIndex: PartitioningAwareFileIndex,
     dataSchema: StructType,
-    readSchema: StructType,
+    readDataSchema: StructType,
+    readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends TextBasedFileScan(sparkSession, fileIndex, readSchema, options) {
+  extends TextBasedFileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema, options) {
 
   private lazy val parsedOptions: CSVOptions = new CSVOptions(
     options.asScala.toMap,
@@ -53,8 +54,8 @@ case class CSVScan(
     // Check a field requirement for corrupt records here to throw an exception in a driver side
     ExprUtils.verifyColumnNameOfCorruptRecord(dataSchema, parsedOptions.columnNameOfCorruptRecord)
 
-    if (readSchema.length == 1 &&
-      readSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
+    if (readDataSchema.length == 1 &&
+      readDataSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
       throw new AnalysisException(
         "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" +
           "referenced columns only include the internal corrupt record column\n" +
@@ -72,7 +73,9 @@ case class CSVScan(
     val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
     val broadcastedConf = sparkSession.sparkContext.broadcast(
       new SerializableConfiguration(hadoopConf))
+    // The partition values are already truncated in `FileScan.partitions`.
+    // We should use `readPartitionSchema` as the partition schema here.
     CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, fileIndex.partitionSchema, readSchema, parsedOptions)
+      dataSchema, readDataSchema, readPartitionSchema, parsedOptions)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala
index dbb3c03ca981..28c5b3d81a3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala
@@ -29,9 +29,10 @@ case class CSVScanBuilder(
     fileIndex: PartitioningAwareFileIndex,
     schema: StructType,
     dataSchema: StructType,
-    options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) {
+    options: CaseInsensitiveStringMap)
+  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
 
   override def build(): Scan = {
-    CSVScan(sparkSession, fileIndex, dataSchema, readSchema, options)
+    CSVScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 1da9469909f1..ec923797e269 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -46,30 +46,30 @@ import org.apache.spark.util.SerializableConfiguration
 * @param sqlConf SQL configuration.
 * @param broadcastedConf Broadcast serializable Hadoop Configuration.
 * @param dataSchema Schema of orc files.
+ * @param readDataSchema Required data schema in the batch scan.
 * @param partitionSchema Schema of partitions.
- * @param readSchema Required schema in the batch scan.
 */
 case class OrcPartitionReaderFactory(
     sqlConf: SQLConf,
     broadcastedConf: Broadcast[SerializableConfiguration],
     dataSchema: StructType,
-    partitionSchema: StructType,
-    readSchema: StructType) extends FilePartitionReaderFactory {
+    readDataSchema: StructType,
+    partitionSchema: StructType) extends FilePartitionReaderFactory {
+  private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val capacity = sqlConf.orcVectorizedReaderBatchSize
 
   override def supportColumnarReads(partition: InputPartition): Boolean = {
     sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
-      readSchema.length <= sqlConf.wholeStageMaxNumFields &&
-      readSchema.forall(_.dataType.isInstanceOf[AtomicType])
+      resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
+      resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
   }
 
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value
 
-    val readDataSchema = getReadDataSchema(readSchema, partitionSchema, isCaseSensitive)
-    val readDataSchemaString = OrcUtils.orcTypeDescriptionString(readDataSchema)
-    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, readDataSchemaString)
+    val resultSchemaString = OrcUtils.orcTypeDescriptionString(resultSchema)
+    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
     OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
 
     val filePath = new Path(new URI(file.filePath))
@@ -113,8 +113,8 @@ case class OrcPartitionReaderFactory(
   override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
     val conf = broadcastedConf.value.value
 
-    val readSchemaString = OrcUtils.orcTypeDescriptionString(readSchema)
-    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, readSchemaString)
+    val resultSchemaString = OrcUtils.orcTypeDescriptionString(resultSchema)
+    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
     OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
 
     val filePath = new Path(new URI(file.filePath))
@@ -124,13 +124,13 @@ case class OrcPartitionReaderFactory(
     val reader = OrcFile.createReader(filePath, readerOptions)
 
     val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds(
-      isCaseSensitive, dataSchema, readSchema, reader, conf)
+      isCaseSensitive, dataSchema, readDataSchema, reader, conf)
 
     if (requestedColIdsOrEmptyFile.isEmpty) {
       new EmptyPartitionReader
     } else {
-      val requestedColIds = requestedColIdsOrEmptyFile.get
-      assert(requestedColIds.length == readSchema.length,
+      val requestedColIds = requestedColIdsOrEmptyFile.get ++ Array.fill(partitionSchema.length)(-1)
+      assert(requestedColIds.length == resultSchema.length,
         "[BUG] requested column IDs do not match required schema")
 
       val taskConf = new Configuration(conf)
@@ -140,15 +140,12 @@ case class OrcPartitionReaderFactory(
       val batchReader = new OrcColumnarBatchReader(capacity)
       batchReader.initialize(fileSplit, taskAttemptContext)
 
-      val columnNameMap = partitionSchema.fields.map(
-        PartitioningUtils.getColName(_, isCaseSensitive)).zipWithIndex.toMap
-      val requestedPartitionColIds = readSchema.fields.map { field =>
-        columnNameMap.getOrElse(PartitioningUtils.getColName(field, isCaseSensitive), -1)
-      }
+      val requestedPartitionColIds =
+        Array.fill(readDataSchema.length)(-1) ++ Range(0, partitionSchema.length)
 
       batchReader.initBatch(
-        TypeDescription.fromString(readSchemaString),
-        readSchema.fields,
+        resultSchema.fields,
         requestedColIds,
         requestedPartitionColIds,
         file.partitionValues)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index fc8a682b226c..dc6b67ceb7e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -32,15 +32,18 @@ case class OrcScan(
     hadoopConf: Configuration,
     fileIndex: PartitioningAwareFileIndex,
     dataSchema: StructType,
-    readSchema: StructType,
+    readDataSchema: StructType,
+    readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends FileScan(sparkSession, fileIndex, readSchema, options) {
+  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
   override def isSplitable(path: Path): Boolean = true
 
   override def createReaderFactory(): PartitionReaderFactory = {
     val broadcastedConf = sparkSession.sparkContext.broadcast(
       new SerializableConfiguration(hadoopConf))
+    // The partition values are already truncated in `FileScan.partitions`.
+    // We should use `readPartitionSchema` as the partition schema here.
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
-      dataSchema, fileIndex.partitionSchema, readSchema)
+      dataSchema, readDataSchema, readPartitionSchema)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
index 8ac56aa5f64b..4c1ec520c6ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
@@ -36,7 +36,7 @@ case class OrcScanBuilder(
     schema: StructType,
     dataSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends FileScanBuilder(schema) with SupportsPushDownFilters {
+  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters {
   lazy val hadoopConf = {
     val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
     // Hadoop Configurations are case sensitive.
@@ -44,7 +44,8 @@ case class OrcScanBuilder(
   }
 
   override def build(): Scan = {
-    OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema, options)
+    OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema,
+      readDataSchema(), readPartitionSchema(), options)
   }
 
   private var _pushedFilters: Array[Filter] = Array.empty
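
The patch above replaces the single `readSchema` that mixed data and partition columns with two schemas: `FileScanBuilder` splits the pruned `requiredSchema` into `readDataSchema()` (non-partition columns) and `readPartitionSchema()` (partition columns), and `FileScan.readSchema()` concatenates them back, data fields first. Below is a minimal, self-contained sketch of that split, not part of the patch: the object and column names are invented, and a plain `toLowerCase` stands in for the `Locale.ROOT` / case-sensitivity handling done by `normalizeName` and `PartitioningUtils.getColName`.

```scala
import org.apache.spark.sql.types.{LongType, StringType, StructType}

// Hypothetical, illustration-only object; mirrors the pruning rules introduced above.
object ReadSchemaSplitSketch {
  def main(args: Array[String]): Unit = {
    val dataSchema = new StructType().add("id", LongType).add("value", StringType)
    val partitionSchema = new StructType().add("dt", StringType)
    // What a pruneColumns() call might push down, in any order.
    val requiredSchema = new StructType().add("value", StringType).add("dt", StringType)

    val requiredNames = requiredSchema.fieldNames.map(_.toLowerCase).toSet
    val partitionNames = partitionSchema.fieldNames.map(_.toLowerCase).toSet

    // Like FileScanBuilder.readDataSchema(): required columns that are not partition columns.
    val readDataSchema = StructType(dataSchema.fields.filter { f =>
      requiredNames.contains(f.name.toLowerCase) && !partitionNames.contains(f.name.toLowerCase)
    })
    // Like FileScanBuilder.readPartitionSchema(): required columns that are partition columns.
    val readPartitionSchema = StructType(partitionSchema.fields.filter { f =>
      requiredNames.contains(f.name.toLowerCase)
    })

    // Like FileScan.readSchema(): data columns first, then partition columns.
    val readSchema = StructType(readDataSchema.fields ++ readPartitionSchema.fields)
    println(readSchema.simpleString) // struct<value:string,dt:string>
  }
}
```

Keeping the two schemas separate is what lets `FileScan.partitions` project partition values down to just `readPartitionSchema`, while the reader factories consume `readDataSchema` directly instead of re-deriving it with the removed `getReadDataSchema` helper.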