From 02af3fabae21aa85b1cc20ad4bf73c0c9a183522 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 4 Aug 2021 08:50:20 -0700 Subject: [PATCH 01/18] [SPARK-34952][SQL] Aggregate (Min/Max/Count) push down for Parquet --- .../apache/spark/sql/internal/SQLConf.scala | 8 + .../parquet/ParquetSchemaConverter.scala | 4 +- .../datasources/parquet/ParquetUtils.scala | 271 +++++++++++++++++- .../ParquetPartitionReaderFactory.scala | 86 ++++-- .../datasources/v2/parquet/ParquetScan.scala | 38 ++- .../org/apache/spark/sql/FileScanSuite.scala | 2 +- 6 files changed, 384 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6443dfd02cec0..8674a6d657273 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -853,6 +853,12 @@ object SQLConf { .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") .createWithDefault(10) + val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") + .doc("Enables Parquet aggregate push-down optimization when set to true.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " + "values will be written in Apache Parquet's fixed-length byte array format, which other " + @@ -3660,6 +3666,8 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownInFilterThreshold: Int = getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index e91a3ce29b79a..436555d921ee6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -585,8 +585,8 @@ private[sql] object ParquetSchemaConverter { Types.buildMessage().named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) def checkFieldName(name: String): Unit = { - // ,;{}()\n\t= and space are special characters in Parquet schema - if (name.matches(".*[ ,;{}()\n\t=].*")) { + // ,;{}\n\t= and space are special characters in Parquet schema + if (name.matches(".*[ ,;{}\n\t=].*")) { throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b91d75c55c513..f61f0942e2cd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -16,11 +16,28 @@ */ package org.apache.spark.sql.execution.datasources.parquet +import java.math.{BigDecimal, BigInteger} +import java.util + +import scala.collection.mutable.ArrayBuilder +import scala.language.existentials + import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.PrimitiveType +import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow +import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitioningUtils} +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} +import org.apache.spark.unsafe.types.UTF8String object ParquetUtils { def inferSchema( @@ -127,4 +144,256 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + + /** + * When the partial Aggregates (Max/Min/Count) are pushed down to parquet, we don't need to + * createRowBaseReader to read data from parquet and aggregate at spark layer. Instead we want + * to get the partial Aggregates (Max/Min/Count) result using the statistics information + * from parquet footer file, and then construct an InternalRow from these Aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + private[sql] def createInternalRowFromAggResult( + footer: ParquetMetadata, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + datetimeRebaseModeInRead: String, + isCaseSensitive: Boolean): InternalRow = { + val (parquetTypes, values) = + getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive) + val mutableRow = new SpecificInternalRow(aggSchema.fields.map(x => x.dataType)) + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + + parquetTypes.zipWithIndex.foreach { + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + aggSchema.fields(i).dataType match { + case ByteType => + mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte) + case ShortType => + mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort) + case IntegerType => + mutableRow.setInt(i, values(i).asInstanceOf[Integer]) + case DateType => + val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer])) + case d: DecimalType => + val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new SparkException("Unexpected type for INT32") + } + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + aggSchema.fields(i).dataType match { + case LongType => + mutableRow.setLong(i, values(i).asInstanceOf[Long]) + case d: DecimalType => + val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new SparkException("Unexpected type for INT64") + } + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + mutableRow.setFloat(i, values(i).asInstanceOf[Float]) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + mutableRow.setDouble(i, values(i).asInstanceOf[Double]) + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean]) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + val bytes = values(i).asInstanceOf[Binary].getBytes + aggSchema.fields(i).dataType match { + case StringType => + mutableRow.update(i, UTF8String.fromBytes(bytes)) + case BinaryType => + mutableRow.update(i, bytes) + case d: DecimalType => + val decimal = + Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new SparkException("Unexpected type for Binary") + } + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + val bytes = values(i).asInstanceOf[Binary].getBytes + aggSchema.fields(i).dataType match { + case d: DecimalType => + val decimal = + Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) + mutableRow.setDecimal(i, decimal, d.precision) + case _ => throw new SparkException("Unexpected type for FIXED_LEN_BYTE_ARRAY") + } + case _ => + throw new SparkException("Unexpected parquet type name") + } + mutableRow + } + + /** + * When the Aggregates (Max/Min/Count) are pushed down to parquet, in the case of + * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader + * to read data from parquet and aggregate at spark layer. Instead we want + * to get the Aggregates (Max/Min/Count) result using the statistics information + * from parquet footer file, and then construct a ColumnarBatch from these Aggregate results. + * + * @return Aggregate results in the format of ColumnarBatch + */ + private[sql] def createColumnarBatchFromAggResult( + footer: ParquetMetadata, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + offHeap: Boolean, + datetimeRebaseModeInRead: String, + isCaseSensitive: Boolean): ColumnarBatch = { + val (parquetTypes, values) = + getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive) + val capacity = 4 * 1024 + val footerFileMetaData = footer.getFileMetaData + val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(capacity, aggSchema) + } else { + OnHeapColumnVector.allocateColumns(capacity, aggSchema) + } + + parquetTypes.zipWithIndex.foreach { + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + aggSchema.fields(i).dataType match { + case ByteType => + columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte) + case ShortType => + columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort) + case IntegerType => + columnVectors(i).appendInt(values(i).asInstanceOf[Integer]) + case DateType => + val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( + datetimeRebaseMode, "Parquet") + columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) + case _ => throw new SparkException("Unexpected type for INT32") + } + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + columnVectors(i).appendLong(values(i).asInstanceOf[Long]) + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + columnVectors(i).appendDouble(values(i).asInstanceOf[Double]) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + val bytes = values(i).asInstanceOf[Binary].getBytes + columnVectors(i).putByteArray(0, bytes, 0, bytes.length) + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + val bytes = values(i).asInstanceOf[Binary].getBytes + columnVectors(i).putByteArray(0, bytes, 0, bytes.length) + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean]) + case _ => + throw new SparkException("Unexpected parquet type name") + } + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } + + /** + * Calculate the pushed down Aggregates (Max/Min/Count) result using the statistics + * information from parquet footer file. + * + * @return A tuple of `Array[PrimitiveType.PrimitiveTypeName]` and Array[Any]. + * The first element is the PrimitiveTypeName of the Aggregate column, + * and the second element is the aggregated value. + */ + private[sql] def getPushedDownAggResult( + footer: ParquetMetadata, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + isCaseSensitive: Boolean) + : (Array[PrimitiveType.PrimitiveTypeName], Array[Any]) = { + val footerFileMetaData = footer.getFileMetaData + val fields = footerFileMetaData.getSchema.getFields + val blocks = footer.getBlocks() + val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] + val valuesBuilder = ArrayBuilder.make[Any] + + aggregation.aggregateExpressions().foreach { agg => + var value: Any = None + var rowCount = 0L + var isCount = false + var index = 0 + blocks.forEach { block => + val blockMetaData = block.getColumns() + agg match { + case max: Max => + index = dataSchema.fieldNames.toList.indexOf(max.column.fieldNames.head) + val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) + if (currentMax != None && + (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { + value = currentMax + } + case min: Min => + index = dataSchema.fieldNames.toList.indexOf(min.column.fieldNames.head) + val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) + if (currentMin != None && + (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { + value = currentMin + } + case count: Count => + + rowCount += block.getRowCount + var isPartitionCol = false; + if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) + .toSet.contains(count.column().fieldNames.head)) { + isPartitionCol = true + } + isCount = true + if(!isPartitionCol) { + index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) + // Count(*) includes the null values, but Count (colName) doesn't. + rowCount -= getNumNulls(blockMetaData, index) + } + case _: CountStar => + rowCount += block.getRowCount + isCount = true + case _ => + } + } + if (isCount) { + valuesBuilder += rowCount + typesBuilder += PrimitiveType.PrimitiveTypeName.INT64 + } else { + valuesBuilder += value + typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName + } + } + (typesBuilder.result(), valuesBuilder.result()) + } + + /** + * get the Max or Min value for ith column in the current block + * + * @return the Max or Min value + */ + private def getCurrentBlockMaxOrMin( + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int, + isMax: Boolean): Any = { + val statistics = columnChunkMetaData.get(i).getStatistics() + if (!statistics.hasNonNullValue) { + throw new UnsupportedOperationException("No min/max found for parquet file, Set SQLConf" + + " PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again") + } else { + if (isMax) statistics.genericGetMax() else statistics.genericGetMin() + } + } + + private def getNumNulls( + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int): Long = { + val statistics = columnChunkMetaData.get(i).getStatistics() + if (!statistics.isNumNullsSet()) { + throw new UnsupportedOperationException("Number of nulls not set for parquet file." + + " Set SQLConf PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again") + } + statistics.getNumNulls(); + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 058669b0937fa..f009b312df27a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -25,14 +25,16 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS} import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} +import org.apache.parquet.hadoop.metadata.ParquetMetadata import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ @@ -53,6 +55,7 @@ import org.apache.spark.util.SerializableConfiguration * @param readDataSchema Required schema of Parquet files. * @param partitionSchema Schema of partitions. * @param filters Filters to be pushed down in the batch scan. + * @param aggregation Aggregation to be pushed down in the batch scan. * @param parquetOptions The options of Parquet datasource that are set for the read. */ case class ParquetPartitionReaderFactory( @@ -62,6 +65,7 @@ case class ParquetPartitionReaderFactory( readDataSchema: StructType, partitionSchema: StructType, filters: Array[Filter], + aggregation: Option[Aggregation], parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields) @@ -80,6 +84,17 @@ case class ParquetPartitionReaderFactory( private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + private def getFooter(file: PartitionedFile): ParquetMetadata = { + val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + if (aggregation.isEmpty) { + ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS) + } else { + ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) + } + } + override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled && resultSchema.length <= sqlConf.wholeStageMaxNumFields && @@ -87,18 +102,35 @@ case class ParquetPartitionReaderFactory( } override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { - val reader = if (enableVectorizedReader) { - createVectorizedReader(file) + val fileReader = if (aggregation.isEmpty) { + val reader = if (enableVectorizedReader) { + createVectorizedReader(file) + } else { + createRowBaseReader(file) + } + + new PartitionReader[InternalRow] { + override def next(): Boolean = reader.nextKeyValue() + + override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + + override def close(): Unit = reader.close() + } } else { - createRowBaseReader(file) - } + new PartitionReader[InternalRow] { + var count = 0 - val fileReader = new PartitionReader[InternalRow] { - override def next(): Boolean = reader.nextKeyValue() + override def next(): Boolean = if (count == 0) true else false - override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + override def get(): InternalRow = { + count += 1 + val footer = getFooter(file) + ParquetUtils.createInternalRowFromAggResult(footer, dataSchema, partitionSchema, + aggregation.get, readDataSchema, datetimeRebaseModeInRead, isCaseSensitive) + } - override def close(): Unit = reader.close() + override def close(): Unit = return + } } new PartitionReaderWithPartitionValues(fileReader, readDataSchema, @@ -106,17 +138,36 @@ case class ParquetPartitionReaderFactory( } override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val vectorizedReader = createVectorizedReader(file) - vectorizedReader.enableReturningBatches() + val fileReader = if (aggregation.isEmpty) { + val vectorizedReader = createVectorizedReader(file) + vectorizedReader.enableReturningBatches() + + new PartitionReader[ColumnarBatch] { + override def next(): Boolean = vectorizedReader.nextKeyValue() - new PartitionReader[ColumnarBatch] { - override def next(): Boolean = vectorizedReader.nextKeyValue() + override def get(): ColumnarBatch = + vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] - override def get(): ColumnarBatch = - vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] + override def close(): Unit = vectorizedReader.close() + } + } else { + new PartitionReader[ColumnarBatch] { + var count = 0 + + override def next(): Boolean = if (count == 0) true else false - override def close(): Unit = vectorizedReader.close() + override def get(): ColumnarBatch = { + count += 1 + val footer = getFooter(file) + ParquetUtils.createColumnarBatchFromAggResult(footer, dataSchema, partitionSchema, + aggregation.get, readDataSchema, enableOffHeapColumnVector, datetimeRebaseModeInRead, + isCaseSensitive) + } + + override def close(): Unit = return + } } + fileReader } private def buildReaderBase[T]( @@ -131,8 +182,7 @@ case class ParquetPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) - lazy val footerFileMetaData = - ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData + lazy val footerFileMetaData = getFooter(file).getFileMetaData val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index e277e334845c9..0c41e013676b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -24,6 +24,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} @@ -43,10 +44,14 @@ case class ParquetScan( readPartitionSchema: StructType, pushedFilters: Array[Filter], options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true + override def readSchema(): StructType = + if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema() + override def createReaderFactory(): PartitionReaderFactory = { val readDataSchemaAsJson = readDataSchema.json hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) @@ -86,23 +91,50 @@ case class ParquetScan( readDataSchema, readPartitionSchema, pushedFilters, + pushedAggregate, new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) } override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => + val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) { + equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) + } else { + pushedAggregate.isEmpty && p.pushedAggregate.isEmpty + } super.equals(p) && dataSchema == p.dataSchema && options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) + equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { + (seqToString(pushedAggregate.get.aggregateExpressions), + seqToString(pushedAggregate.get.groupByColumns)) + } else { + ("[]", "[]") + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) + } + + override def withFilters( + partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = + this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) + + private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && + a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index d0877dbf316c7..604a8927aa7af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -354,7 +354,7 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => From a03b960f22e712edb17267c9dc723201f5645e63 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 4 Aug 2021 17:51:25 -0700 Subject: [PATCH 02/18] ignore special chars () check in parquet column name --- .../sql/execution/datasources/parquet/ParquetUtils.scala | 1 - .../org/apache/spark/sql/hive/HiveParquetSourceSuite.scala | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index f61f0942e2cd1..d5930082bbfa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -338,7 +338,6 @@ object ParquetUtils { value = currentMin } case count: Count => - rowCount += block.getRowCount var isPartitionCol = false; if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala index b3ea54a7bc931..4977cbd5dd477 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala @@ -206,7 +206,9 @@ class HiveParquetSourceSuite extends ParquetPartitioningTest { } } - test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + // After pushing down aggregate to parquet, we can have something like MAX(C) in column name + // ignore this test for now + ignore("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { withTempDir { tempDir => val filePath = new File(tempDir, "testParquet").getCanonicalPath val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath From 2d5aeb1da73ccb99810349684b178fe2a8293ce1 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 5 Aug 2021 23:28:21 -0700 Subject: [PATCH 03/18] address comments --- .../apache/spark/sql/internal/SQLConf.scala | 3 +- .../datasources/parquet/ParquetUtils.scala | 26 +- .../ParquetPartitionReaderFactory.scala | 4 +- .../ParquetAggregatePushDownSuite.scala | 495 ++++++++++++++++++ .../sql/hive/HiveParquetSourceSuite.scala | 2 +- 5 files changed, 516 insertions(+), 14 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8674a6d657273..798429cff2fbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -854,7 +854,8 @@ object SQLConf { .createWithDefault(10) val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") - .doc("Enables Parquet aggregate push-down optimization when set to true.") + .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + + " down to parquet for optimization. ") .version("3.3.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index d5930082bbfa2..8cfa8f76ae87b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitioningUtils} import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.internal.SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.types.UTF8String @@ -153,7 +154,7 @@ object ParquetUtils { * * @return Aggregate results in the format of InternalRow */ - private[sql] def createInternalRowFromAggResult( + private[sql] def createAggInternalRowFromFooter( footer: ParquetMetadata, dataSchema: StructType, partitionSchema: StructType, @@ -184,7 +185,8 @@ object ParquetUtils { case d: DecimalType => val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new SparkException("Unexpected type for INT32") + case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + + " for INT32") } case (PrimitiveType.PrimitiveTypeName.INT64, i) => aggSchema.fields(i).dataType match { @@ -193,7 +195,8 @@ object ParquetUtils { case d: DecimalType => val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new SparkException("Unexpected type for INT64") + case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + + " for INT64") } case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => mutableRow.setFloat(i, values(i).asInstanceOf[Float]) @@ -212,7 +215,8 @@ object ParquetUtils { val decimal = Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new SparkException("Unexpected type for Binary") + case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + + " for Binary") } case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => val bytes = values(i).asInstanceOf[Binary].getBytes @@ -221,7 +225,8 @@ object ParquetUtils { val decimal = Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new SparkException("Unexpected type for FIXED_LEN_BYTE_ARRAY") + case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + + " for FIXED_LEN_BYTE_ARRAY") } case _ => throw new SparkException("Unexpected parquet type name") @@ -238,7 +243,7 @@ object ParquetUtils { * * @return Aggregate results in the format of ColumnarBatch */ - private[sql] def createColumnarBatchFromAggResult( + private[sql] def createAggColumnarBatchFromFooter( footer: ParquetMetadata, dataSchema: StructType, partitionSchema: StructType, @@ -272,7 +277,8 @@ object ParquetUtils { val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( datetimeRebaseMode, "Parquet") columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) - case _ => throw new SparkException("Unexpected type for INT32") + case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + + s" for INT32") } case (PrimitiveType.PrimitiveTypeName.INT64, i) => columnVectors(i).appendLong(values(i).asInstanceOf[Long]) @@ -368,7 +374,7 @@ object ParquetUtils { } /** - * get the Max or Min value for ith column in the current block + * Get the Max or Min value for ith column in the current block * * @return the Max or Min value */ @@ -379,7 +385,7 @@ object ParquetUtils { val statistics = columnChunkMetaData.get(i).getStatistics() if (!statistics.hasNonNullValue) { throw new UnsupportedOperationException("No min/max found for parquet file, Set SQLConf" + - " PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again") + s" ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") } else { if (isMax) statistics.genericGetMax() else statistics.genericGetMin() } @@ -391,7 +397,7 @@ object ParquetUtils { val statistics = columnChunkMetaData.get(i).getStatistics() if (!statistics.isNumNullsSet()) { throw new UnsupportedOperationException("Number of nulls not set for parquet file." + - " Set SQLConf PARQUET_AGGREGATE_PUSHDOWN_ENABLED to false and execute again") + s" Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") } statistics.getNumNulls(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index f009b312df27a..0a191a710026a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -125,7 +125,7 @@ case class ParquetPartitionReaderFactory( override def get(): InternalRow = { count += 1 val footer = getFooter(file) - ParquetUtils.createInternalRowFromAggResult(footer, dataSchema, partitionSchema, + ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema, aggregation.get, readDataSchema, datetimeRebaseModeInRead, isCaseSensitive) } @@ -159,7 +159,7 @@ case class ParquetPartitionReaderFactory( override def get(): ColumnarBatch = { count += 1 val footer = getFooter(file) - ParquetUtils.createColumnarBatchFromAggResult(footer, dataSchema, partitionSchema, + ParquetUtils.createAggColumnarBatchFromFooter(footer, dataSchema, partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector, datetimeRebaseModeInRead, isCaseSensitive) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala new file mode 100644 index 0000000000000..2e1b5c97659a3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -0,0 +1,495 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.sql.{Date, Timestamp} + +import org.apache.spark.SparkConf +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.functions.min +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +/** + * A test suite that tests Max/Min/Count push down. + */ +abstract class ParquetAggregatePushDownSuite + extends QueryTest + with ParquetTest + with SharedSparkSession + with ExplainSuiteHelper { + import testImplicits._ + + test("aggregate push down - nested column: Max(top level column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val max = sql("SELECT Max(_1) FROM t") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + } + } + } + + test("aggregate push down - nested column: Count(top level column) push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val count = sql("SELECT Count(_1) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(_1)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + + test("aggregate push down - nested column: Max(nested column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val max = sql("SELECT Max(_1._2[0]) FROM t") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + } + } + } + + test("aggregate push down - nested column: Count(nested column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + withParquetTable(data, "t") { + val count = sql("SELECT Count(_1._2[0]) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + + test("aggregate push down - Max(partition Col): not push dow") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val max = sql("SELECT Max(p) FROM tmp") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + checkAnswer(max, Seq(Row(2))) + } + } + } + } + + test("aggregate push down - Count(partition Col): push down") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + val count = sql("SELECT COUNT(p) FROM tmp") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(p)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + } + } + + test("aggregate push down - Filter alias over aggregate") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1), MAX(_1)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(7))) + } + } + } + + test("aggregate push down - alias over aggregate") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-1, 0))) + } + } + } + + test("aggregate push down - aggregate over alias not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val df = spark.table("t") + val query = df.select($"_1".as("col1")).agg(min($"col1")) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" // aggregate alias not pushed down + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(-2))) + } + } + } + + test("aggregate push down - query with group by not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // aggregate not pushed down if there is group by + val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3))) + } + } + } + + test("aggregate push down - query with filter not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // aggregate not pushed down if there is filter + val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(2))) + } + } + } + + test("aggregate push down - push down only if all the aggregates can be pushed down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + // not push down since sum can't be pushed down + val selectAgg = sql("SELECT min(_1), sum(_3) FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-2, 41))) + } + } + } + + test("aggregate push down - MIN/MAX/COUNT") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withParquetTable(data, "t") { + withSQLConf( + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") { + val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + + " count(*), count(_1), count(_2), count(_3) FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_3), " + + "MAX(_3), " + + "MIN(_1), " + + "MAX(_1), " + + "COUNT(*), " + + "COUNT(_1), " + + "COUNT(_2), " + + "COUNT(_3)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + + checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) + } + } + } + + test("aggregate push down - different data types") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val rows = + Seq( + Row( + "a string", + true, + 10.toByte, + "Spark SQL".getBytes, + 12.toShort, + 3, + Long.MaxValue, + 0.15.toFloat, + 0.75D, + Decimal("12.345678"), + ("2021-01-01").date, + ("2015-01-01 23:50:59.123").ts), + Row( + "test string", + false, + 1.toByte, + "Parquet".getBytes, + 2.toShort, + null, + Long.MinValue, + 0.25.toFloat, + 0.85D, + Decimal("1.2345678"), + ("2015-01-01").date, + ("2021-01-01 23:50:59.123").ts), + Row( + null, + true, + 10000.toByte, + "Spark ML".getBytes, + 222.toShort, + 113, + 11111111L, + 0.25.toFloat, + 0.75D, + Decimal("12345.678"), + ("2004-06-19").date, + ("1999-08-26 10:43:59.123").ts) + ) + + val schema = StructType(List(StructField("StringCol", StringType, true), + StructField("BooleanCol", BooleanType, false), + StructField("ByteCol", ByteType, false), + StructField("BinaryCol", BinaryType, false), + StructField("ShortCol", ShortType, false), + StructField("IntegerCol", IntegerType, true), + StructField("LongCol", LongType, false), + StructField("FloatCol", FloatType, false), + StructField("DoubleCol", DoubleType, false), + StructField("DecimalCol", DecimalType(25, 5), true), + StructField("DateCol", DateType, false), + StructField("TimestampCol", TimestampType, false)).toArray) + + val rdd = sparkContext.parallelize(rows) + withTempPath { file => + spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath) + withTempView("test") { + spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test") + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + + val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + testMinWithTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment) + } + + checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) + + val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test") + + testMinWithOutTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(StringCol), " + + "MIN(BooleanCol), " + + "MIN(ByteCol), " + + "MIN(BinaryCol), " + + "MIN(ShortCol), " + + "MIN(IntegerCol), " + + "MIN(LongCol), " + + "MIN(FloatCol), " + + "MIN(DoubleCol), " + + "MIN(DecimalCol), " + + "MIN(DateCol)]" + checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment) + } + + checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + ("2004-06-19").date))) + + val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + testMaxWithTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment) + } + + checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + + val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test") + + testMaxWithoutTS.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(StringCol), " + + "MAX(BooleanCol), " + + "MAX(ByteCol), " + + "MAX(BinaryCol), " + + "MAX(ShortCol), " + + "MAX(IntegerCol), " + + "MAX(LongCol), " + + "MAX(FloatCol), " + + "MAX(DoubleCol), " + + "MAX(DecimalCol), " + + "MAX(DateCol)]" + checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment) + } + + checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date))) + + val testCount = sql("SELECT count(*), count(StringCol), count(BooleanCol)," + + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + + " count(LongCol), count(FloatCol), count(DoubleCol)," + + " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test") + + testCount.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(*), " + + "COUNT(StringCol), " + + "COUNT(BooleanCol), " + + "COUNT(ByteCol), " + + "COUNT(BinaryCol), " + + "COUNT(ShortCol), " + + "COUNT(IntegerCol), " + + "COUNT(LongCol), " + + "COUNT(FloatCol), " + + "COUNT(DoubleCol), " + + "COUNT(DecimalCol), " + + "COUNT(DateCol), " + + "COUNT(TimestampCol)]" + checkKeywordsExistsInExplain(testCount, expected_plan_fragment) + } + + checkAnswer(testCount, Seq(Row(3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) + } + } + } + } + } +} + +class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "parquet") +} + +class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite { + + // TODO: enable Parquet V2 write path after file source V2 writers are workable. + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_LIST, "") +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala index 4977cbd5dd477..27115bff902a1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala @@ -206,7 +206,7 @@ class HiveParquetSourceSuite extends ParquetPartitioningTest { } } - // After pushing down aggregate to parquet, we can have something like MAX(C) in column name + // We can have something like MAX(C) in column name for aggregate push down // ignore this test for now ignore("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { withTempDir { tempDir => From 0cf218048b3ba2975d1af5ea8eb55fafa65c09e2 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 11 Aug 2021 19:43:42 -0700 Subject: [PATCH 04/18] address comments --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../datasources/PartitioningUtils.scala | 11 ++--- .../datasources/parquet/ParquetUtils.scala | 49 ++++++++++--------- .../ParquetPartitionReaderFactory.scala | 16 +++--- .../ParquetAggregatePushDownSuite.scala | 24 +++++++++ 5 files changed, 64 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 798429cff2fbf..2ac847f517e65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -855,7 +855,7 @@ object SQLConf { val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + - " down to parquet for optimization. ") + " down to Parquet for optimization. ") .version("3.3.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 184e63179d47b..f05cdda3f0812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -602,13 +602,10 @@ object PartitioningUtils { (fullSchema, overlappedPartCols.toMap) } - def getColName(f: StructField, caseSensitive: Boolean): String = { - if (caseSensitive) { - f.name - } else { - f.name.toLowerCase(Locale.ROOT) - } - } + def getColName(f: StructField, caseSensitive: Boolean): String = getColName(f.name, caseSensitive) + + def getColName(s: String, caseSensitive: Boolean): String = + if (caseSensitive) s else s.toLowerCase(Locale.ROOT) private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = { if (caseSensitive) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 8cfa8f76ae87b..7d33e6f0e4682 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -147,10 +147,10 @@ object ParquetUtils { } /** - * When the partial Aggregates (Max/Min/Count) are pushed down to parquet, we don't need to - * createRowBaseReader to read data from parquet and aggregate at spark layer. Instead we want - * to get the partial Aggregates (Max/Min/Count) result using the statistics information - * from parquet footer file, and then construct an InternalRow from these Aggregate results. + * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to + * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want + * to get the partial aggregates (Max/Min/Count) result using the statistics information + * from Parquet footer file, and then construct an InternalRow from these aggregate results. * * @return Aggregate results in the format of InternalRow */ @@ -235,11 +235,11 @@ object ParquetUtils { } /** - * When the Aggregates (Max/Min/Count) are pushed down to parquet, in the case of + * When the aggregates (Max/Min/Count) are pushed down to Parquet, in the case of * PARQUET_VECTORIZED_READER_ENABLED sets to true, we don't need buildColumnarReader - * to read data from parquet and aggregate at spark layer. Instead we want - * to get the Aggregates (Max/Min/Count) result using the statistics information - * from parquet footer file, and then construct a ColumnarBatch from these Aggregate results. + * to read data from Parquet and aggregate at Spark layer. Instead we want + * to get the aggregates (Max/Min/Count) result using the statistics information + * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results. * * @return Aggregate results in the format of ColumnarBatch */ @@ -301,11 +301,11 @@ object ParquetUtils { } /** - * Calculate the pushed down Aggregates (Max/Min/Count) result using the statistics - * information from parquet footer file. + * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics + * information from Parquet footer file. * * @return A tuple of `Array[PrimitiveType.PrimitiveTypeName]` and Array[Any]. - * The first element is the PrimitiveTypeName of the Aggregate column, + * The first element is the PrimitiveTypeName of the aggregate column, * and the second element is the aggregated value. */ private[sql] def getPushedDownAggResult( @@ -330,29 +330,34 @@ object ParquetUtils { val blockMetaData = block.getColumns() agg match { case max: Max => - index = dataSchema.fieldNames.toList.indexOf(max.column.fieldNames.head) + val colName = PartitioningUtils.getColName(max.column.fieldNames.head, isCaseSensitive) + index = dataSchema.fields.map(PartitioningUtils + .getColName(_, isCaseSensitive)).toList.indexOf(colName) val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) - if (currentMax != None && - (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0)) { + if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { value = currentMax } case min: Min => - index = dataSchema.fieldNames.toList.indexOf(min.column.fieldNames.head) + val colName = PartitioningUtils.getColName(min.column.fieldNames.head, isCaseSensitive) + index = dataSchema.fields.map(PartitioningUtils + .getColName(_, isCaseSensitive)).toList.indexOf(colName) val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) - if (currentMin != None && - (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0)) { + if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { value = currentMin } case count: Count => rowCount += block.getRowCount - var isPartitionCol = false; + var isPartitionCol = false + val colName = + PartitioningUtils.getColName(count.column.fieldNames.head, isCaseSensitive) if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) - .toSet.contains(count.column().fieldNames.head)) { + .toSet.contains(colName)) { isPartitionCol = true } isCount = true if(!isPartitionCol) { - index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) + index = dataSchema.fields.map(PartitioningUtils + .getColName(_, isCaseSensitive)).toList.indexOf(colName) // Count(*) includes the null values, but Count (colName) doesn't. rowCount -= getNumNulls(blockMetaData, index) } @@ -384,7 +389,7 @@ object ParquetUtils { isMax: Boolean): Any = { val statistics = columnChunkMetaData.get(i).getStatistics() if (!statistics.hasNonNullValue) { - throw new UnsupportedOperationException("No min/max found for parquet file, Set SQLConf" + + throw new UnsupportedOperationException("No min/max found for Parquet file, Set SQLConf" + s" ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") } else { if (isMax) statistics.genericGetMax() else statistics.genericGetMin() @@ -396,7 +401,7 @@ object ParquetUtils { i: Int): Long = { val statistics = columnChunkMetaData.get(i).getStatistics() if (!statistics.isNumNullsSet()) { - throw new UnsupportedOperationException("Number of nulls not set for parquet file." + + throw new UnsupportedOperationException("Number of nulls not set for Parquet file." + s" Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") } statistics.getNumNulls(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 0a191a710026a..00b97fb5a061b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -118,18 +118,18 @@ case class ParquetPartitionReaderFactory( } } else { new PartitionReader[InternalRow] { - var count = 0 + var hasNext = true - override def next(): Boolean = if (count == 0) true else false + override def next(): Boolean = hasNext override def get(): InternalRow = { - count += 1 + hasNext = false val footer = getFooter(file) ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema, aggregation.get, readDataSchema, datetimeRebaseModeInRead, isCaseSensitive) } - override def close(): Unit = return + override def close(): Unit = {} } } @@ -152,19 +152,19 @@ case class ParquetPartitionReaderFactory( } } else { new PartitionReader[ColumnarBatch] { - var count = 0 + var hasNext = true - override def next(): Boolean = if (count == 0) true else false + override def next(): Boolean = hasNext override def get(): ColumnarBatch = { - count += 1 + hasNext = false val footer = getFooter(file) ParquetUtils.createAggColumnarBatchFromFooter(footer, dataSchema, partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector, datetimeRebaseModeInRead, isCaseSensitive) } - override def close(): Unit = return + override def close(): Unit = {} } } fileReader diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala index 2e1b5c97659a3..68fe85baedc88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -475,6 +475,30 @@ abstract class ParquetAggregatePushDownSuite } } } + + test("aggregate push down - column name case sensitivity") { + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").parquet(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp"); + val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(iD), MIN(Id)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(9, 0))) + } + } + } + } + } } class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { From d3afeb14d2340f1b3fee0033ab56799c0252723a Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 15 Aug 2021 16:43:17 -0700 Subject: [PATCH 05/18] fix case sensitivity and use RowToColumnConverter for ColumnarBatch path --- .../datasources/PartitioningUtils.scala | 11 +-- .../datasources/parquet/ParquetUtils.scala | 70 +++++-------------- .../ParquetAggregatePushDownSuite.scala | 2 +- 3 files changed, 25 insertions(+), 58 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index f05cdda3f0812..184e63179d47b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -602,10 +602,13 @@ object PartitioningUtils { (fullSchema, overlappedPartCols.toMap) } - def getColName(f: StructField, caseSensitive: Boolean): String = getColName(f.name, caseSensitive) - - def getColName(s: String, caseSensitive: Boolean): String = - if (caseSensitive) s else s.toLowerCase(Locale.ROOT) + def getColName(f: StructField, caseSensitive: Boolean): String = { + if (caseSensitive) { + f.name + } else { + f.name.toLowerCase(Locale.ROOT) + } + } private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = { if (caseSensitive) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 7d33e6f0e4682..b54c6cfa663c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitioningUtils} import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED @@ -252,51 +253,21 @@ object ParquetUtils { offHeap: Boolean, datetimeRebaseModeInRead: String, isCaseSensitive: Boolean): ColumnarBatch = { - val (parquetTypes, values) = - getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive) - val capacity = 4 * 1024 - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( - footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + val row = createAggInternalRowFromFooter( + footer, + dataSchema, + partitionSchema, + aggregation, + aggSchema, + datetimeRebaseModeInRead, + isCaseSensitive) + val converter = new RowToColumnConverter(aggSchema) val columnVectors = if (offHeap) { - OffHeapColumnVector.allocateColumns(capacity, aggSchema) + OffHeapColumnVector.allocateColumns(4 * 1024, aggSchema) } else { - OnHeapColumnVector.allocateColumns(capacity, aggSchema) - } - - parquetTypes.zipWithIndex.foreach { - case (PrimitiveType.PrimitiveTypeName.INT32, i) => - aggSchema.fields(i).dataType match { - case ByteType => - columnVectors(i).appendByte(values(i).asInstanceOf[Integer].toByte) - case ShortType => - columnVectors(i).appendShort(values(i).asInstanceOf[Integer].toShort) - case IntegerType => - columnVectors(i).appendInt(values(i).asInstanceOf[Integer]) - case DateType => - val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( - datetimeRebaseMode, "Parquet") - columnVectors(i).appendInt(dateRebaseFunc(values(i).asInstanceOf[Integer])) - case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + - s" for INT32") - } - case (PrimitiveType.PrimitiveTypeName.INT64, i) => - columnVectors(i).appendLong(values(i).asInstanceOf[Long]) - case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => - columnVectors(i).appendFloat(values(i).asInstanceOf[Float]) - case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => - columnVectors(i).appendDouble(values(i).asInstanceOf[Double]) - case (PrimitiveType.PrimitiveTypeName.BINARY, i) => - val bytes = values(i).asInstanceOf[Binary].getBytes - columnVectors(i).putByteArray(0, bytes, 0, bytes.length) - case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => - val bytes = values(i).asInstanceOf[Binary].getBytes - columnVectors(i).putByteArray(0, bytes, 0, bytes.length) - case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => - columnVectors(i).appendBoolean(values(i).asInstanceOf[Boolean]) - case _ => - throw new SparkException("Unexpected parquet type name") + OnHeapColumnVector.allocateColumns(4 * 1024, aggSchema) } + converter.convert(row, columnVectors.toArray) new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) } @@ -330,17 +301,13 @@ object ParquetUtils { val blockMetaData = block.getColumns() agg match { case max: Max => - val colName = PartitioningUtils.getColName(max.column.fieldNames.head, isCaseSensitive) - index = dataSchema.fields.map(PartitioningUtils - .getColName(_, isCaseSensitive)).toList.indexOf(colName) + index = dataSchema.fieldNames.toList.indexOf(max.column.fieldNames.head) val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { value = currentMax } case min: Min => - val colName = PartitioningUtils.getColName(min.column.fieldNames.head, isCaseSensitive) - index = dataSchema.fields.map(PartitioningUtils - .getColName(_, isCaseSensitive)).toList.indexOf(colName) + index = dataSchema.fieldNames.toList.indexOf(min.column.fieldNames.head) val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { value = currentMin @@ -348,16 +315,13 @@ object ParquetUtils { case count: Count => rowCount += block.getRowCount var isPartitionCol = false - val colName = - PartitioningUtils.getColName(count.column.fieldNames.head, isCaseSensitive) if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) - .toSet.contains(colName)) { + .toSet.contains(count.column().fieldNames.head)) { isPartitionCol = true } isCount = true if(!isPartitionCol) { - index = dataSchema.fields.map(PartitioningUtils - .getColName(_, isCaseSensitive)).toList.indexOf(colName) + index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) // Count(*) includes the null values, but Count (colName) doesn't. rowCount -= getNumNulls(blockMetaData, index) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala index 68fe85baedc88..bddef8525ebdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -490,7 +490,7 @@ abstract class ParquetAggregatePushDownSuite selectAgg.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [MAX(iD), MIN(Id)]" + "PushedAggregation: [MAX(id), MIN(id)]" checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) } checkAnswer(selectAgg, Seq(Row(9, 0))) From 8deeaaf4bb60b346d0bced0c97221e87e05c6aae Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 27 Aug 2021 16:19:13 -0700 Subject: [PATCH 06/18] use ParquetRowConverter to build data row --- .../datasources/parquet/ParquetUtils.scala | 126 ++++++++---------- .../ParquetAggregatePushDownSuite.scala | 63 ++++++--- 2 files changed, 99 insertions(+), 90 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b54c6cfa663c5..a1c4d7b2c6562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.execution.datasources.parquet -import java.math.{BigDecimal, BigInteger} import java.util import scala.collection.mutable.ArrayBuilder @@ -26,20 +25,19 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.ParquetFileWriter import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.PrimitiveType +import org.apache.parquet.schema.{PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.RowToColumnConverter -import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} -import org.apache.spark.sql.internal.SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED -import org.apache.spark.sql.types.{BinaryType, ByteType, DateType, Decimal, DecimalType, IntegerType, LongType, ShortType, StringType, StructType} +import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} +import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} -import org.apache.spark.unsafe.types.UTF8String object ParquetUtils { def inferSchema( @@ -163,76 +161,42 @@ object ParquetUtils { aggSchema: StructType, datetimeRebaseModeInRead: String, isCaseSensitive: Boolean): InternalRow = { - val (parquetTypes, values) = + val (primitiveTypeName, primitiveType, values) = getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive) - val mutableRow = new SpecificInternalRow(aggSchema.fields.map(x => x.dataType)) - val footerFileMetaData = footer.getFileMetaData - val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( - footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) - parquetTypes.zipWithIndex.foreach { + val builder = Types.buildMessage() + primitiveType.foreach(t => builder.addField(t)) + val parquetSchema = builder.named("root") + + val schemaConverter = new ParquetToSparkSchemaConverter + val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, + None, LegacyBehaviorPolicy.CORRECTED, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) + primitiveTypeName.zipWithIndex.foreach { + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + val v = values(i).asInstanceOf[Boolean] + converter.getConverter(i).asPrimitiveConverter().addBoolean(v) case (PrimitiveType.PrimitiveTypeName.INT32, i) => - aggSchema.fields(i).dataType match { - case ByteType => - mutableRow.setByte(i, values(i).asInstanceOf[Integer].toByte) - case ShortType => - mutableRow.setShort(i, values(i).asInstanceOf[Integer].toShort) - case IntegerType => - mutableRow.setInt(i, values(i).asInstanceOf[Integer]) - case DateType => - val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( - datetimeRebaseMode, "Parquet") - mutableRow.update(i, dateRebaseFunc(values(i).asInstanceOf[Integer])) - case d: DecimalType => - val decimal = Decimal(values(i).asInstanceOf[Integer].toLong, d.precision, d.scale) - mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + - " for INT32") - } + val v = values(i).asInstanceOf[Integer] + converter.getConverter(i).asPrimitiveConverter().addInt(v) case (PrimitiveType.PrimitiveTypeName.INT64, i) => - aggSchema.fields(i).dataType match { - case LongType => - mutableRow.setLong(i, values(i).asInstanceOf[Long]) - case d: DecimalType => - val decimal = Decimal(values(i).asInstanceOf[Long], d.precision, d.scale) - mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + - " for INT64") - } + val v = values(i).asInstanceOf[Long] + converter.getConverter(i).asPrimitiveConverter().addLong(v) case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => - mutableRow.setFloat(i, values(i).asInstanceOf[Float]) + val v = values(i).asInstanceOf[Float] + converter.getConverter(i).asPrimitiveConverter().addFloat(v) case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => - mutableRow.setDouble(i, values(i).asInstanceOf[Double]) - case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => - mutableRow.setBoolean(i, values(i).asInstanceOf[Boolean]) + val v = values(i).asInstanceOf[Double] + converter.getConverter(i).asPrimitiveConverter().addDouble(v) case (PrimitiveType.PrimitiveTypeName.BINARY, i) => - val bytes = values(i).asInstanceOf[Binary].getBytes - aggSchema.fields(i).dataType match { - case StringType => - mutableRow.update(i, UTF8String.fromBytes(bytes)) - case BinaryType => - mutableRow.update(i, bytes) - case d: DecimalType => - val decimal = - Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) - mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + - " for Binary") - } + val v = values(i).asInstanceOf[Binary] + converter.getConverter(i).asPrimitiveConverter().addBinary(v) case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => - val bytes = values(i).asInstanceOf[Binary].getBytes - aggSchema.fields(i).dataType match { - case d: DecimalType => - val decimal = - Decimal(new BigDecimal(new BigInteger(bytes), d.scale), d.precision, d.scale) - mutableRow.setDecimal(i, decimal, d.precision) - case _ => throw new SparkException(s"Unexpected type ${aggSchema.fields(i).dataType}" + - " for FIXED_LEN_BYTE_ARRAY") - } + val v = values(i).asInstanceOf[Binary] + converter.getConverter(i).asPrimitiveConverter().addBinary(v) case _ => throw new SparkException("Unexpected parquet type name") } - mutableRow + converter.currentRecord } /** @@ -285,11 +249,12 @@ object ParquetUtils { partitionSchema: StructType, aggregation: Aggregation, isCaseSensitive: Boolean) - : (Array[PrimitiveType.PrimitiveTypeName], Array[Any]) = { + : (Array[PrimitiveType.PrimitiveTypeName], Array[PrimitiveType], Array[Any]) = { val footerFileMetaData = footer.getFileMetaData val fields = footerFileMetaData.getSchema.getFields val blocks = footer.getBlocks() - val typesBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] + val primitiveTypeBuilder = ArrayBuilder.make[PrimitiveType] + val primitiveTypeNameBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] val valuesBuilder = ArrayBuilder.make[Any] aggregation.aggregateExpressions().foreach { agg => @@ -297,22 +262,26 @@ object ParquetUtils { var rowCount = 0L var isCount = false var index = 0 + var schemaName = "" blocks.forEach { block => val blockMetaData = block.getColumns() agg match { case max: Max => index = dataSchema.fieldNames.toList.indexOf(max.column.fieldNames.head) + schemaName = "max(" + max.column.fieldNames.head + ")" val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { value = currentMax } case min: Min => index = dataSchema.fieldNames.toList.indexOf(min.column.fieldNames.head) + schemaName = "min(" + min.column.fieldNames.head + ")" val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { value = currentMin } case count: Count => + schemaName = "count(" + count.column.fieldNames.head + ")" rowCount += block.getRowCount var isPartitionCol = false if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) @@ -326,6 +295,7 @@ object ParquetUtils { rowCount -= getNumNulls(blockMetaData, index) } case _: CountStar => + schemaName = "count(*)" rowCount += block.getRowCount isCount = true case _ => @@ -333,13 +303,27 @@ object ParquetUtils { } if (isCount) { valuesBuilder += rowCount - typesBuilder += PrimitiveType.PrimitiveTypeName.INT64 + primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName); + primitiveTypeNameBuilder += PrimitiveType.PrimitiveTypeName.INT64 } else { valuesBuilder += value - typesBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName + if (fields.get(index).asPrimitiveType().getPrimitiveTypeName + .equals(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)) { // for decimal type + val decimal = dataSchema.fields(index).dataType.asInstanceOf[DecimalType] + val precision = decimal.precision + val scale = decimal.scale + val length = precision + scale + primitiveTypeBuilder += Types.required(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) + // note: .precision and .scale are deprecated + .length(length).precision(precision).scale(scale).named(schemaName) + } else { + primitiveTypeBuilder += + Types.required(fields.get(index).asPrimitiveType.getPrimitiveTypeName).named(schemaName) + } + primitiveTypeNameBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName } } - (typesBuilder.result(), valuesBuilder.result()) + (primitiveTypeNameBuilder.result(), primitiveTypeBuilder.result(), valuesBuilder.result()) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala index bddef8525ebdf..9c65f7d829d2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -307,7 +307,7 @@ abstract class ParquetAggregatePushDownSuite Long.MaxValue, 0.15.toFloat, 0.75D, - Decimal("12.345678"), + // Decimal("12.345678"), ("2021-01-01").date, ("2015-01-01 23:50:59.123").ts), Row( @@ -320,7 +320,7 @@ abstract class ParquetAggregatePushDownSuite Long.MinValue, 0.25.toFloat, 0.85D, - Decimal("1.2345678"), + // Decimal("1.2345678"), ("2015-01-01").date, ("2021-01-01 23:50:59.123").ts), Row( @@ -333,7 +333,7 @@ abstract class ParquetAggregatePushDownSuite 11111111L, 0.25.toFloat, 0.75D, - Decimal("12345.678"), + // Decimal("12345.678"), ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts) ) @@ -347,7 +347,7 @@ abstract class ParquetAggregatePushDownSuite StructField("LongCol", LongType, false), StructField("FloatCol", FloatType, false), StructField("DoubleCol", DoubleType, false), - StructField("DecimalCol", DecimalType(25, 5), true), + // StructField("DecimalCol", DecimalType(25, 5), true), StructField("DateCol", DateType, false), StructField("TimestampCol", TimestampType, false)).toArray) @@ -363,7 +363,7 @@ abstract class ParquetAggregatePushDownSuite val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + - "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test") + "min(DoubleCol), min(DateCol), min(TimestampCol) FROM test") // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down @@ -375,12 +375,12 @@ abstract class ParquetAggregatePushDownSuite } checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, - 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + - "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test") + "min(DoubleCol), min(DateCol) FROM test") testMinWithOutTS.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -394,18 +394,18 @@ abstract class ParquetAggregatePushDownSuite "MIN(LongCol), " + "MIN(FloatCol), " + "MIN(DoubleCol), " + - "MIN(DecimalCol), " + + // "MIN(DecimalCol), " + "MIN(DateCol)]" checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment) } checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, - 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date))) val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + - "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test") + "max(DoubleCol), max(DateCol), max(TimestampCol) FROM test") // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down @@ -418,11 +418,11 @@ abstract class ParquetAggregatePushDownSuite checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte, "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, - 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + - "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test") + "max(DoubleCol), max(DateCol) FROM test") testMaxWithoutTS.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -436,24 +436,26 @@ abstract class ParquetAggregatePushDownSuite "MAX(LongCol), " + "MAX(FloatCol), " + "MAX(DoubleCol), " + - "MAX(DecimalCol), " + + // "MAX(DecimalCol), " + "MAX(DateCol)]" checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment) } checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte, "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, - 12345.678, ("2021-01-01").date))) + ("2021-01-01").date))) - val testCount = sql("SELECT count(*), count(StringCol), count(BooleanCol)," + + val testCountStar = sql("SELECT count(*) FROM test") + + val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + " count(LongCol), count(FloatCol), count(DoubleCol)," + - " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test") + " count(DateCol), count(TimestampCol) FROM test") testCount.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = - "PushedAggregation: [COUNT(*), " + + "PushedAggregation: [" + "COUNT(StringCol), " + "COUNT(BooleanCol), " + "COUNT(ByteCol), " + @@ -463,13 +465,36 @@ abstract class ParquetAggregatePushDownSuite "COUNT(LongCol), " + "COUNT(FloatCol), " + "COUNT(DoubleCol), " + - "COUNT(DecimalCol), " + "COUNT(DateCol), " + "COUNT(TimestampCol)]" checkKeywordsExistsInExplain(testCount, expected_plan_fragment) } - checkAnswer(testCount, Seq(Row(3, 2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) + checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3))) + } + } + } + } + } + + ignore("decimal test") { + val rows = + Seq(Row(Decimal("12.345678")), Row(Decimal("1.2345678")), Row(Decimal("12345.678"))) + + val schema = StructType(List(StructField("DecimalCol", DecimalType(25, 5), true)).toArray) + + val rdd = sparkContext.parallelize(rows) + withTempPath { file => + spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath) + withTempView("test") { + spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test") + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + val test = sql("SELECT min(DecimalCol) FROM test") + test.show(false) + test.explain(true) } } } From 4e3c69ab4a04462886c04cb889e4f00a4c3a609f Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 27 Aug 2021 17:19:40 -0700 Subject: [PATCH 07/18] rebase and update to the latest --- .../spark/sql/execution/datasources/parquet/ParquetUtils.scala | 2 +- .../datasources/v2/parquet/ParquetPartitionReaderFactory.scala | 2 +- .../sql/execution/datasources/v2/parquet/ParquetScan.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index a1c4d7b2c6562..46c147fff257a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -31,7 +31,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.expressions.{Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 00b97fb5a061b..a6e879643b766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -34,7 +34,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.connector.expressions.Aggregation +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 0c41e013676b9..6a02125038a66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -24,7 +24,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.connector.expressions.Aggregation +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} From c5f8c491bab764033c4d823789a80baa06d0743e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sat, 28 Aug 2021 22:02:03 -0700 Subject: [PATCH 08/18] fix decimal converter problem --- .../datasources/parquet/ParquetUtils.scala | 35 +++++------- .../ParquetPartitionReaderFactory.scala | 2 +- .../ParquetAggregatePushDownSuite.scala | 57 ++++++------------- 3 files changed, 33 insertions(+), 61 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 46c147fff257a..b4ff9b46385f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.ParquetFileWriter import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{PrimitiveType, Types} +import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.SparkException @@ -159,9 +159,8 @@ object ParquetUtils { partitionSchema: StructType, aggregation: Aggregation, aggSchema: StructType, - datetimeRebaseModeInRead: String, isCaseSensitive: Boolean): InternalRow = { - val (primitiveTypeName, primitiveType, values) = + val (primitiveType, values) = getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive) val builder = Types.buildMessage() @@ -171,6 +170,7 @@ object ParquetUtils { val schemaConverter = new ParquetToSparkSchemaConverter val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, None, LegacyBehaviorPolicy.CORRECTED, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) + val primitiveTypeName = primitiveType.map(_.getPrimitiveTypeName) primitiveTypeName.zipWithIndex.foreach { case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => val v = values(i).asInstanceOf[Boolean] @@ -194,7 +194,7 @@ object ParquetUtils { val v = values(i).asInstanceOf[Binary] converter.getConverter(i).asPrimitiveConverter().addBinary(v) case _ => - throw new SparkException("Unexpected parquet type name") + throw new SparkException("Unexpected parquet type name: " + primitiveTypeName) } converter.currentRecord } @@ -223,7 +223,6 @@ object ParquetUtils { partitionSchema, aggregation, aggSchema, - datetimeRebaseModeInRead, isCaseSensitive) val converter = new RowToColumnConverter(aggSchema) val columnVectors = if (offHeap) { @@ -249,12 +248,11 @@ object ParquetUtils { partitionSchema: StructType, aggregation: Aggregation, isCaseSensitive: Boolean) - : (Array[PrimitiveType.PrimitiveTypeName], Array[PrimitiveType], Array[Any]) = { + : (Array[PrimitiveType], Array[Any]) = { val footerFileMetaData = footer.getFileMetaData val fields = footerFileMetaData.getSchema.getFields val blocks = footer.getBlocks() val primitiveTypeBuilder = ArrayBuilder.make[PrimitiveType] - val primitiveTypeNameBuilder = ArrayBuilder.make[PrimitiveType.PrimitiveTypeName] val valuesBuilder = ArrayBuilder.make[Any] aggregation.aggregateExpressions().foreach { agg => @@ -267,15 +265,17 @@ object ParquetUtils { val blockMetaData = block.getColumns() agg match { case max: Max => - index = dataSchema.fieldNames.toList.indexOf(max.column.fieldNames.head) - schemaName = "max(" + max.column.fieldNames.head + ")" + val colName = max.column.fieldNames.head + index = dataSchema.fieldNames.toList.indexOf(colName) + schemaName = "max(" + colName + ")" val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { value = currentMax } case min: Min => - index = dataSchema.fieldNames.toList.indexOf(min.column.fieldNames.head) - schemaName = "min(" + min.column.fieldNames.head + ")" + val colName = min.column.fieldNames.head + index = dataSchema.fieldNames.toList.indexOf(colName) + schemaName = "min(" + colName + ")" val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { value = currentMin @@ -304,26 +304,21 @@ object ParquetUtils { if (isCount) { valuesBuilder += rowCount primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName); - primitiveTypeNameBuilder += PrimitiveType.PrimitiveTypeName.INT64 } else { valuesBuilder += value if (fields.get(index).asPrimitiveType().getPrimitiveTypeName .equals(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)) { // for decimal type val decimal = dataSchema.fields(index).dataType.asInstanceOf[DecimalType] - val precision = decimal.precision - val scale = decimal.scale - val length = precision + scale - primitiveTypeBuilder += Types.required(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) - // note: .precision and .scale are deprecated - .length(length).precision(precision).scale(scale).named(schemaName) + primitiveTypeBuilder += Types.required(PrimitiveTypeName.BINARY) + .as(LogicalTypeAnnotation.decimalType(decimal.scale, decimal.precision)) + .named(schemaName) } else { primitiveTypeBuilder += Types.required(fields.get(index).asPrimitiveType.getPrimitiveTypeName).named(schemaName) } - primitiveTypeNameBuilder += fields.get(index).asPrimitiveType.getPrimitiveTypeName } } - (primitiveTypeNameBuilder.result(), primitiveTypeBuilder.result(), valuesBuilder.result()) + (primitiveTypeBuilder.result(), valuesBuilder.result()) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index a6e879643b766..87a89f8f0bf81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -126,7 +126,7 @@ case class ParquetPartitionReaderFactory( hasNext = false val footer = getFooter(file) ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema, - aggregation.get, readDataSchema, datetimeRebaseModeInRead, isCaseSensitive) + aggregation.get, readDataSchema, isCaseSensitive) } override def close(): Unit = {} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala index 9c65f7d829d2c..e06ca2ae2ef23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -307,7 +307,7 @@ abstract class ParquetAggregatePushDownSuite Long.MaxValue, 0.15.toFloat, 0.75D, - // Decimal("12.345678"), + Decimal("12.345678"), ("2021-01-01").date, ("2015-01-01 23:50:59.123").ts), Row( @@ -320,7 +320,7 @@ abstract class ParquetAggregatePushDownSuite Long.MinValue, 0.25.toFloat, 0.85D, - // Decimal("1.2345678"), + Decimal("1.2345678"), ("2015-01-01").date, ("2021-01-01 23:50:59.123").ts), Row( @@ -333,7 +333,7 @@ abstract class ParquetAggregatePushDownSuite 11111111L, 0.25.toFloat, 0.75D, - // Decimal("12345.678"), + Decimal("12345.678"), ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts) ) @@ -347,7 +347,7 @@ abstract class ParquetAggregatePushDownSuite StructField("LongCol", LongType, false), StructField("FloatCol", FloatType, false), StructField("DoubleCol", DoubleType, false), - // StructField("DecimalCol", DecimalType(25, 5), true), + StructField("DecimalCol", DecimalType(25, 5), true), StructField("DateCol", DateType, false), StructField("TimestampCol", TimestampType, false)).toArray) @@ -363,7 +363,7 @@ abstract class ParquetAggregatePushDownSuite val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + - "min(DoubleCol), min(DateCol), min(TimestampCol) FROM test") + "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test") // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down @@ -375,12 +375,12 @@ abstract class ParquetAggregatePushDownSuite } checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, - 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts))) val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + - "min(DoubleCol), min(DateCol) FROM test") + "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test") testMinWithOutTS.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -394,18 +394,18 @@ abstract class ParquetAggregatePushDownSuite "MIN(LongCol), " + "MIN(FloatCol), " + "MIN(DoubleCol), " + - // "MIN(DecimalCol), " + + "MIN(DecimalCol), " + "MIN(DateCol)]" checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment) } checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes, - 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457, ("2004-06-19").date))) val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + - "max(DoubleCol), max(DateCol), max(TimestampCol) FROM test") + "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test") // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type // so aggregates are not pushed down @@ -418,11 +418,11 @@ abstract class ParquetAggregatePushDownSuite checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte, "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, - ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts))) val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " + "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + - "max(DoubleCol), max(DateCol) FROM test") + "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test") testMaxWithoutTS.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -436,21 +436,21 @@ abstract class ParquetAggregatePushDownSuite "MAX(LongCol), " + "MAX(FloatCol), " + "MAX(DoubleCol), " + - // "MAX(DecimalCol), " + + "MAX(DecimalCol), " + "MAX(DateCol)]" checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment) } checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte, "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, - ("2021-01-01").date))) + 12345.678, ("2021-01-01").date))) val testCountStar = sql("SELECT count(*) FROM test") val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + " count(LongCol), count(FloatCol), count(DoubleCol)," + - " count(DateCol), count(TimestampCol) FROM test") + " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test") testCount.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => @@ -465,36 +465,13 @@ abstract class ParquetAggregatePushDownSuite "COUNT(LongCol), " + "COUNT(FloatCol), " + "COUNT(DoubleCol), " + + "COUNT(DecimalCol), " + "COUNT(DateCol), " + "COUNT(TimestampCol)]" checkKeywordsExistsInExplain(testCount, expected_plan_fragment) } - checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3))) - } - } - } - } - } - - ignore("decimal test") { - val rows = - Seq(Row(Decimal("12.345678")), Row(Decimal("1.2345678")), Row(Decimal("12345.678"))) - - val schema = StructType(List(StructField("DecimalCol", DecimalType(25, 5), true)).toArray) - - val rdd = sparkContext.parallelize(rows) - withTempPath { file => - spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath) - withTempView("test") { - spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test") - val enableVectorizedReader = Seq("false", "true") - for (testVectorizedReader <- enableVectorizedReader) { - withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true", - vectorizedReaderEnabledKey -> testVectorizedReader) { - val test = sql("SELECT min(DecimalCol) FROM test") - test.show(false) - test.explain(true) + checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))) } } } From eafbe1fcae8fe8888950890ecce5f6f3493ed7fd Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 31 Aug 2021 15:06:05 -0700 Subject: [PATCH 09/18] address comments --- .../datasources/parquet/ParquetUtils.scala | 43 ++++++++----------- .../ParquetPartitionReaderFactory.scala | 3 +- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b4ff9b46385f3..9adeb50f4233e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util -import scala.collection.mutable.ArrayBuilder +import scala.collection.mutable import scala.language.existentials import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.ParquetFileWriter import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Types} +import org.apache.parquet.schema.{PrimitiveType, Types} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.SparkException @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.RowToColumnConverter import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} -import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} object ParquetUtils { @@ -193,8 +193,8 @@ object ParquetUtils { case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => val v = values(i).asInstanceOf[Binary] converter.getConverter(i).asPrimitiveConverter().addBinary(v) - case _ => - throw new SparkException("Unexpected parquet type name: " + primitiveTypeName) + case (_, i) => + throw new SparkException("Unexpected parquet type name: " + primitiveTypeName(i)) } converter.currentRecord } @@ -214,8 +214,8 @@ object ParquetUtils { partitionSchema: StructType, aggregation: Aggregation, aggSchema: StructType, + columnBatchSize: Int, offHeap: Boolean, - datetimeRebaseModeInRead: String, isCaseSensitive: Boolean): ColumnarBatch = { val row = createAggInternalRowFromFooter( footer, @@ -226,9 +226,9 @@ object ParquetUtils { isCaseSensitive) val converter = new RowToColumnConverter(aggSchema) val columnVectors = if (offHeap) { - OffHeapColumnVector.allocateColumns(4 * 1024, aggSchema) + OffHeapColumnVector.allocateColumns(columnBatchSize, aggSchema) } else { - OnHeapColumnVector.allocateColumns(4 * 1024, aggSchema) + OnHeapColumnVector.allocateColumns(columnBatchSize, aggSchema) } converter.convert(row, columnVectors.toArray) new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) @@ -238,8 +238,8 @@ object ParquetUtils { * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics * information from Parquet footer file. * - * @return A tuple of `Array[PrimitiveType.PrimitiveTypeName]` and Array[Any]. - * The first element is the PrimitiveTypeName of the aggregate column, + * @return A tuple of `Array[PrimitiveType]` and Array[Any]. + * The first element is the Parquet PrimitiveType of the aggregate column, * and the second element is the aggregated value. */ private[sql] def getPushedDownAggResult( @@ -251,9 +251,9 @@ object ParquetUtils { : (Array[PrimitiveType], Array[Any]) = { val footerFileMetaData = footer.getFileMetaData val fields = footerFileMetaData.getSchema.getFields - val blocks = footer.getBlocks() - val primitiveTypeBuilder = ArrayBuilder.make[PrimitiveType] - val valuesBuilder = ArrayBuilder.make[Any] + val blocks = footer.getBlocks + val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType] + val valuesBuilder = mutable.ArrayBuilder.make[Any] aggregation.aggregateExpressions().foreach { agg => var value: Any = None @@ -262,7 +262,7 @@ object ParquetUtils { var index = 0 var schemaName = "" blocks.forEach { block => - val blockMetaData = block.getColumns() + val blockMetaData = block.getColumns agg match { case max: Max => val colName = max.column.fieldNames.head @@ -306,16 +306,11 @@ object ParquetUtils { primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName); } else { valuesBuilder += value - if (fields.get(index).asPrimitiveType().getPrimitiveTypeName - .equals(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY)) { // for decimal type - val decimal = dataSchema.fields(index).dataType.asInstanceOf[DecimalType] - primitiveTypeBuilder += Types.required(PrimitiveTypeName.BINARY) - .as(LogicalTypeAnnotation.decimalType(decimal.scale, decimal.precision)) - .named(schemaName) - } else { - primitiveTypeBuilder += - Types.required(fields.get(index).asPrimitiveType.getPrimitiveTypeName).named(schemaName) - } + val field = fields.get(index) + primitiveTypeBuilder += Types.required(field.asPrimitiveType().getPrimitiveTypeName) + .as(field.getLogicalTypeAnnotation) + .length(field.asPrimitiveType().getTypeLength) + .named(schemaName) } } (primitiveTypeBuilder.result(), valuesBuilder.result()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 87a89f8f0bf81..b829e4cca2f30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -83,6 +83,7 @@ case class ParquetPartitionReaderFactory( private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + private val columnBatchSize = sqlConf.columnBatchSize private def getFooter(file: PartitionedFile): ParquetMetadata = { val conf = broadcastedConf.value.value @@ -160,7 +161,7 @@ case class ParquetPartitionReaderFactory( hasNext = false val footer = getFooter(file) ParquetUtils.createAggColumnarBatchFromFooter(footer, dataSchema, partitionSchema, - aggregation.get, readDataSchema, enableOffHeapColumnVector, datetimeRebaseModeInRead, + aggregation.get, readDataSchema, columnBatchSize, enableOffHeapColumnVector, isCaseSensitive) } From 4417fc473d4d83e1d32001a85f9e0bf17261298b Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 31 Aug 2021 22:03:35 -0700 Subject: [PATCH 10/18] pass datetimeRebaseMode --- .../datasources/parquet/ParquetUtils.scala | 5 ++++- .../ParquetPartitionReaderFactory.scala | 18 ++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 9adeb50f4233e..aa1b4433c725f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -159,6 +159,7 @@ object ParquetUtils { partitionSchema: StructType, aggregation: Aggregation, aggSchema: StructType, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, isCaseSensitive: Boolean): InternalRow = { val (primitiveType, values) = getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive) @@ -169,7 +170,7 @@ object ParquetUtils { val schemaConverter = new ParquetToSparkSchemaConverter val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, - None, LegacyBehaviorPolicy.CORRECTED, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) + None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) val primitiveTypeName = primitiveType.map(_.getPrimitiveTypeName) primitiveTypeName.zipWithIndex.foreach { case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => @@ -216,6 +217,7 @@ object ParquetUtils { aggSchema: StructType, columnBatchSize: Int, offHeap: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, isCaseSensitive: Boolean): ColumnarBatch = { val row = createAggInternalRowFromFooter( footer, @@ -223,6 +225,7 @@ object ParquetUtils { partitionSchema, aggregation, aggSchema, + datetimeRebaseMode, isCaseSensitive) val converter = new RowToColumnConverter(aggSchema) val columnVectors = if (offHeap) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index b829e4cca2f30..4f2d519d228f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -27,7 +27,7 @@ import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS} import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} -import org.apache.parquet.hadoop.metadata.ParquetMetadata +import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast @@ -96,6 +96,13 @@ case class ParquetPartitionReaderFactory( } } + private def getDatetimeRebaseMode( + footerFileMetaData: FileMetaData): LegacyBehaviorPolicy.Value = { + DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + } + override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled && resultSchema.length <= sqlConf.wholeStageMaxNumFields && @@ -127,7 +134,8 @@ case class ParquetPartitionReaderFactory( hasNext = false val footer = getFooter(file) ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema, - aggregation.get, readDataSchema, isCaseSensitive) + aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData), + isCaseSensitive) } override def close(): Unit = {} @@ -162,7 +170,7 @@ case class ParquetPartitionReaderFactory( val footer = getFooter(file) ParquetUtils.createAggColumnarBatchFromFooter(footer, dataSchema, partitionSchema, aggregation.get, readDataSchema, columnBatchSize, enableOffHeapColumnVector, - isCaseSensitive) + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) } override def close(): Unit = {} @@ -184,9 +192,7 @@ case class ParquetPartitionReaderFactory( val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) lazy val footerFileMetaData = getFooter(file).getFileMetaData - val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) + val datetimeRebaseMode = getDatetimeRebaseMode(footerFileMetaData) // Try to push down filters when filter push-down is enabled. val pushed = if (enableParquetFilterPushDown) { val parquetSchema = footerFileMetaData.getSchema From 8b515ed77a6cbf1c449f7ca4fb7a3b4624533695 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 1 Sep 2021 14:00:47 -0700 Subject: [PATCH 11/18] address comments --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 3 ++- .../datasources/parquet/ParquetSchemaConverter.scala | 2 +- .../sql/execution/datasources/parquet/ParquetUtils.scala | 5 ++--- .../v2/parquet/ParquetPartitionReaderFactory.scala | 6 +++--- .../datasources/parquet/ParquetAggregatePushDownSuite.scala | 1 - .../org/apache/spark/sql/hive/HiveParquetSourceSuite.scala | 2 +- 6 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2ac847f517e65..2f74815e8950d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -855,7 +855,8 @@ object SQLConf { val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + - " down to Parquet for optimization. ") + " down to Parquet for optimization. MAX/MIN/COUNT for Complex type and Timestamp" + + "can't be pushed down") .version("3.3.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 436555d921ee6..a2f60351e57c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -586,7 +586,7 @@ private[sql] object ParquetSchemaConverter { def checkFieldName(name: String): Unit = { // ,;{}\n\t= and space are special characters in Parquet schema - if (name.matches(".*[ ,;{}\n\t=].*")) { + if (name.matches(".*[ ,;{}()\n\t=].*")) { throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index aa1b4433c725f..420936e3a63e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -215,7 +215,6 @@ object ParquetUtils { partitionSchema: StructType, aggregation: Aggregation, aggSchema: StructType, - columnBatchSize: Int, offHeap: Boolean, datetimeRebaseMode: LegacyBehaviorPolicy.Value, isCaseSensitive: Boolean): ColumnarBatch = { @@ -229,9 +228,9 @@ object ParquetUtils { isCaseSensitive) val converter = new RowToColumnConverter(aggSchema) val columnVectors = if (offHeap) { - OffHeapColumnVector.allocateColumns(columnBatchSize, aggSchema) + OffHeapColumnVector.allocateColumns(1, aggSchema) } else { - OnHeapColumnVector.allocateColumns(columnBatchSize, aggSchema) + OnHeapColumnVector.allocateColumns(1, aggSchema) } converter.convert(row, columnVectors.toArray) new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 4f2d519d228f7..040b3eecae5de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -126,7 +126,7 @@ case class ParquetPartitionReaderFactory( } } else { new PartitionReader[InternalRow] { - var hasNext = true + private var hasNext = true override def next(): Boolean = hasNext @@ -161,7 +161,7 @@ case class ParquetPartitionReaderFactory( } } else { new PartitionReader[ColumnarBatch] { - var hasNext = true + private var hasNext = true override def next(): Boolean = hasNext @@ -169,7 +169,7 @@ case class ParquetPartitionReaderFactory( hasNext = false val footer = getFooter(file) ParquetUtils.createAggColumnarBatchFromFooter(footer, dataSchema, partitionSchema, - aggregation.get, readDataSchema, columnBatchSize, enableOffHeapColumnVector, + aggregation.get, readDataSchema, enableOffHeapColumnVector, getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala index e06ca2ae2ef23..0eb8371e97c31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -513,7 +513,6 @@ class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite { - // TODO: enable Parquet V2 write path after file source V2 writers are workable. override protected def sparkConf: SparkConf = super .sparkConf diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala index 27115bff902a1..8dab4059a85b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala @@ -208,7 +208,7 @@ class HiveParquetSourceSuite extends ParquetPartitioningTest { // We can have something like MAX(C) in column name for aggregate push down // ignore this test for now - ignore("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { + test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { withTempDir { tempDir => val filePath = new File(tempDir, "testParquet").getCanonicalPath val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath From 9ebe889c3f58c608844c1b4a4e95867618974dee Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 1 Sep 2021 14:19:30 -0700 Subject: [PATCH 12/18] fix space --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../execution/datasources/parquet/ParquetSchemaConverter.scala | 2 +- .../org/apache/spark/sql/hive/HiveParquetSourceSuite.scala | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2f74815e8950d..924de9a2ed4f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -856,7 +856,7 @@ object SQLConf { val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + " down to Parquet for optimization. MAX/MIN/COUNT for Complex type and Timestamp" + - "can't be pushed down") + " can't be pushed down") .version("3.3.0") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index a2f60351e57c4..bc2b2e35dfbef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -586,7 +586,7 @@ private[sql] object ParquetSchemaConverter { def checkFieldName(name: String): Unit = { // ,;{}\n\t= and space are special characters in Parquet schema - if (name.matches(".*[ ,;{}()\n\t=].*")) { + if (name.matches(".*[ ,;{}()\n\t=].*")) { throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala index 8dab4059a85b3..b3ea54a7bc931 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSourceSuite.scala @@ -206,8 +206,6 @@ class HiveParquetSourceSuite extends ParquetPartitioningTest { } } - // We can have something like MAX(C) in column name for aggregate push down - // ignore this test for now test("Aggregation attribute names can't contain special chars \" ,;{}()\\n\\t=\"") { withTempDir { tempDir => val filePath = new File(tempDir, "testParquet").getCanonicalPath From 0aadd50eaa5c5ed98864774cfae07e52e81fc2a1 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 1 Sep 2021 14:21:20 -0700 Subject: [PATCH 13/18] remove unnessary change --- .../execution/datasources/parquet/ParquetSchemaConverter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index bc2b2e35dfbef..e91a3ce29b79a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -585,7 +585,7 @@ private[sql] object ParquetSchemaConverter { Types.buildMessage().named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) def checkFieldName(name: String): Unit = { - // ,;{}\n\t= and space are special characters in Parquet schema + // ,;{}()\n\t= and space are special characters in Parquet schema if (name.matches(".*[ ,;{}()\n\t=].*")) { throw QueryCompilationErrors.columnNameContainsInvalidCharactersError(name) } From a63e34c5aa71ba07a3c07d74720c024a1ff407ed Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 9 Sep 2021 08:04:24 -0700 Subject: [PATCH 14/18] resolve conflict --- .../sql/execution/datasources/v2/parquet/ParquetScan.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 6a02125038a66..ba56f5eda232d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -128,10 +128,6 @@ case class ParquetScan( Map("PushedGroupBy" -> pushedGroupByStr) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { a.aggregateExpressions.sortBy(_.hashCode()) .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && From 358d1ce60aadd7e3e316ee3190ad98aff200a6c5 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Thu, 23 Sep 2021 14:49:54 -0700 Subject: [PATCH 15/18] address comments --- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../datasources/parquet/ParquetUtils.scala | 46 +++++++++---------- .../ParquetPartitionReaderFactory.scala | 46 ++++++++++++++----- 3 files changed, 58 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 924de9a2ed4f7..98aad1c9c83f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -855,7 +855,7 @@ object SQLConf { val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + - " down to Parquet for optimization. MAX/MIN/COUNT for Complex type and Timestamp" + + " down to Parquet for optimization. MAX/MIN/COUNT for complex types and timestamp" + " can't be pushed down") .version("3.3.0") .booleanConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 420936e3a63e1..98eae9515cc13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -161,39 +161,39 @@ object ParquetUtils { aggSchema: StructType, datetimeRebaseMode: LegacyBehaviorPolicy.Value, isCaseSensitive: Boolean): InternalRow = { - val (primitiveType, values) = + val (primitiveTypes, values) = getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive) - val builder = Types.buildMessage() - primitiveType.foreach(t => builder.addField(t)) + val builder = Types.buildMessage + primitiveTypes.foreach(t => builder.addField(t)) val parquetSchema = builder.named("root") val schemaConverter = new ParquetToSparkSchemaConverter val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) - val primitiveTypeName = primitiveType.map(_.getPrimitiveTypeName) + val primitiveTypeName = primitiveTypes.map(_.getPrimitiveTypeName) primitiveTypeName.zipWithIndex.foreach { case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => val v = values(i).asInstanceOf[Boolean] - converter.getConverter(i).asPrimitiveConverter().addBoolean(v) + converter.getConverter(i).asPrimitiveConverter.addBoolean(v) case (PrimitiveType.PrimitiveTypeName.INT32, i) => val v = values(i).asInstanceOf[Integer] - converter.getConverter(i).asPrimitiveConverter().addInt(v) + converter.getConverter(i).asPrimitiveConverter.addInt(v) case (PrimitiveType.PrimitiveTypeName.INT64, i) => val v = values(i).asInstanceOf[Long] - converter.getConverter(i).asPrimitiveConverter().addLong(v) + converter.getConverter(i).asPrimitiveConverter.addLong(v) case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => val v = values(i).asInstanceOf[Float] - converter.getConverter(i).asPrimitiveConverter().addFloat(v) + converter.getConverter(i).asPrimitiveConverter.addFloat(v) case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => val v = values(i).asInstanceOf[Double] - converter.getConverter(i).asPrimitiveConverter().addDouble(v) + converter.getConverter(i).asPrimitiveConverter.addDouble(v) case (PrimitiveType.PrimitiveTypeName.BINARY, i) => val v = values(i).asInstanceOf[Binary] - converter.getConverter(i).asPrimitiveConverter().addBinary(v) + converter.getConverter(i).asPrimitiveConverter.addBinary(v) case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => val v = values(i).asInstanceOf[Binary] - converter.getConverter(i).asPrimitiveConverter().addBinary(v) + converter.getConverter(i).asPrimitiveConverter.addBinary(v) case (_, i) => throw new SparkException("Unexpected parquet type name: " + primitiveTypeName(i)) } @@ -257,7 +257,7 @@ object ParquetUtils { val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType] val valuesBuilder = mutable.ArrayBuilder.make[Any] - aggregation.aggregateExpressions().foreach { agg => + aggregation.aggregateExpressions.foreach { agg => var value: Any = None var rowCount = 0L var isCount = false @@ -287,13 +287,13 @@ object ParquetUtils { rowCount += block.getRowCount var isPartitionCol = false if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)) - .toSet.contains(count.column().fieldNames.head)) { + .toSet.contains(count.column.fieldNames.head)) { isPartitionCol = true } isCount = true - if(!isPartitionCol) { + if (!isPartitionCol) { index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) - // Count(*) includes the null values, but Count (colName) doesn't. + // Count(*) includes the null values, but Count(colName) doesn't. rowCount -= getNumNulls(blockMetaData, index) } case _: CountStar => @@ -309,13 +309,13 @@ object ParquetUtils { } else { valuesBuilder += value val field = fields.get(index) - primitiveTypeBuilder += Types.required(field.asPrimitiveType().getPrimitiveTypeName) + primitiveTypeBuilder += Types.required(field.asPrimitiveType.getPrimitiveTypeName) .as(field.getLogicalTypeAnnotation) - .length(field.asPrimitiveType().getTypeLength) + .length(field.asPrimitiveType.getTypeLength) .named(schemaName) } } - (primitiveTypeBuilder.result(), valuesBuilder.result()) + (primitiveTypeBuilder.result, valuesBuilder.result) } /** @@ -327,23 +327,23 @@ object ParquetUtils { columnChunkMetaData: util.List[ColumnChunkMetaData], i: Int, isMax: Boolean): Any = { - val statistics = columnChunkMetaData.get(i).getStatistics() + val statistics = columnChunkMetaData.get(i).getStatistics if (!statistics.hasNonNullValue) { throw new UnsupportedOperationException("No min/max found for Parquet file, Set SQLConf" + s" ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") } else { - if (isMax) statistics.genericGetMax() else statistics.genericGetMin() + if (isMax) statistics.genericGetMax else statistics.genericGetMin } } private def getNumNulls( columnChunkMetaData: util.List[ColumnChunkMetaData], i: Int): Long = { - val statistics = columnChunkMetaData.get(i).getStatistics() - if (!statistics.isNumNullsSet()) { + val statistics = columnChunkMetaData.get(i).getStatistics + if (!statistics.isNumNullsSet) { throw new UnsupportedOperationException("Number of nulls not set for Parquet file." + s" Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") } - statistics.getNumNulls(); + statistics.getNumNulls; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 040b3eecae5de..a58a2192a50b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -83,7 +83,6 @@ case class ParquetPartitionReaderFactory( private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead - private val columnBatchSize = sqlConf.columnBatchSize private def getFooter(file: PartitionedFile): ParquetMetadata = { val conf = broadcastedConf.value.value @@ -92,6 +91,12 @@ case class ParquetPartitionReaderFactory( if (aggregation.isEmpty) { ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS) } else { + // For aggregate push down, we will get max/min/count from footer statistics. + // We want to read the footer for the whole file instead of reading multiple + // footers for every split of the file. Basically if the start (the beginning of) + // the offset in PartitionedFile is 0, we will read the footer. Otherwise, it means + // that we have already read footer for that file, so we will skip reading again. + if (file.start != 0) return null ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) } } @@ -127,15 +132,23 @@ case class ParquetPartitionReaderFactory( } else { new PartitionReader[InternalRow] { private var hasNext = true - - override def next(): Boolean = hasNext + private lazy val row: InternalRow = { + val footer = getFooter(file) + if (footer != null && footer.getBlocks.size > 0) { + ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema, + aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData), + isCaseSensitive) + } else { + null + } + } + override def next(): Boolean = { + hasNext && row != null + } override def get(): InternalRow = { hasNext = false - val footer = getFooter(file) - ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema, - aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData), - isCaseSensitive) + row } override def close(): Unit = {} @@ -162,15 +175,24 @@ case class ParquetPartitionReaderFactory( } else { new PartitionReader[ColumnarBatch] { private var hasNext = true + private var row: ColumnarBatch = { + val footer = getFooter(file) + if (footer != null && footer.getBlocks.size > 0) { + ParquetUtils.createAggColumnarBatchFromFooter(footer, dataSchema, partitionSchema, + aggregation.get, readDataSchema, enableOffHeapColumnVector, + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + } else { + null + } + } - override def next(): Boolean = hasNext + override def next(): Boolean = { + hasNext && row != null + } override def get(): ColumnarBatch = { hasNext = false - val footer = getFooter(file) - ParquetUtils.createAggColumnarBatchFromFooter(footer, dataSchema, partitionSchema, - aggregation.get, readDataSchema, enableOffHeapColumnVector, - getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + row } override def close(): Unit = {} From 9e7560b9c621610bf4e4bdba2055faca4c5beb55 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 1 Oct 2021 15:28:35 -0700 Subject: [PATCH 16/18] rebase --- .../v2/parquet/ParquetScanBuilder.scala | 99 +++++++++++++++++-- 1 file changed, 93 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 9a0e4b4794fc0..4d7b266f7a8ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.Scan -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} +import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils} import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ParquetScanBuilder( @@ -35,7 +37,8 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownAggregates{ lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -70,6 +73,10 @@ case class ParquetScanBuilder( } } + private var finalSchema = new StructType() + + private var pushedAggregations = Option.empty[Aggregation] + override protected val supportsNestedSchemaPruning: Boolean = true override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters @@ -79,8 +86,88 @@ case class ParquetScanBuilder( // All filters that can be converted to Parquet are pushed down. override def pushedFilters(): Array[Filter] = pushedParquetFilters + override def pushAggregation(aggregation: Aggregation): Boolean = { + + def getStructFieldForCol(col: NamedReference): StructField = { + schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head)) + } + + def isPartitionCol(col: NamedReference) = { + (readPartitionSchema.fields.map(PartitioningUtils + .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis)) + .toSet.contains(col.fieldNames.head)) + } + + def processMinOrMax(agg: AggregateFunc): Boolean = { + val (column, aggType) = agg match { + case max: Max => (max.column, "max") + case min: Min => (min.column, "min") + case _ => throw new IllegalArgumentException("Unexpected type of AggregateFunc") + } + + if (isPartitionCol(column)) { + // don't push down partition column, footer doesn't have max/min for partition column + return false + } + val structField = getStructFieldForCol(column) + + structField.dataType match { + // not push down complex type + // not push down Timestamp because INT96 sort order is undefined, + // Parquet doesn't return statistics for INT96 + case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType => + false + case _ => + finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) + true + } + } + + if (!sparkSession.sessionState.conf.parquetAggregatePushDown || + aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) { + // Parquet footer has max/min/count for columns + // e.g. SELECT COUNT(col1) FROM t + // but footer doesn't have max/min/count for a column if max/min/count + // are combined with filter or group by + // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8 + // SELECT COUNT(col1) FROM t GROUP BY col2 + // Todo: 1. add support if groupby column is partition col + // (https://issues.apache.org/jira/browse/SPARK-36646) + // 2. add support if filter col is partition col + // (https://issues.apache.org/jira/browse/SPARK-36647) + return false + } + + aggregation.groupByColumns.foreach { col => + if (col.fieldNames.length != 1) return false + finalSchema = finalSchema.add(getStructFieldForCol(col)) + } + + aggregation.aggregateExpressions.foreach { + case max: Max => + if (!processMinOrMax(max)) return false + case min: Min => + if (!processMinOrMax(min)) return false + case count: Count => + if (count.column.fieldNames.length != 1 || count.isDistinct) return false + finalSchema = + finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType)) + case _: CountStar => + finalSchema = finalSchema.add(StructField("count(*)", LongType)) + case _ => + return false + } + this.pushedAggregations = Some(aggregation) + true + } + override def build(): Scan = { - ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters) + // the `finalSchema` is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in readDataSchema() (in regular column pruning). These + // two are mutual exclusive. + if (pushedAggregations.isEmpty) finalSchema = readDataSchema() + ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, + readPartitionSchema(), pushedParquetFilters, options, pushedAggregations, + partitionFilters, dataFilters) } } From df6ca868b8b27dc74d4ec198dd04b3aee2c59d69 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Fri, 1 Oct 2021 17:16:04 -0700 Subject: [PATCH 17/18] address comments --- .../datasources/parquet/ParquetUtils.scala | 31 ++++++++++++------- .../ParquetPartitionReaderFactory.scala | 12 +++---- .../v2/parquet/ParquetScanBuilder.scala | 3 +- .../ParquetAggregatePushDownSuite.scala | 2 -- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 98eae9515cc13..ce96e3714e425 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -155,14 +155,15 @@ object ParquetUtils { */ private[sql] def createAggInternalRowFromFooter( footer: ParquetMetadata, + filePath: String, dataSchema: StructType, partitionSchema: StructType, aggregation: Aggregation, aggSchema: StructType, datetimeRebaseMode: LegacyBehaviorPolicy.Value, isCaseSensitive: Boolean): InternalRow = { - val (primitiveTypes, values) = - getPushedDownAggResult(footer, dataSchema, partitionSchema, aggregation, isCaseSensitive) + val (primitiveTypes, values) = getPushedDownAggResult( + footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive) val builder = Types.buildMessage primitiveTypes.foreach(t => builder.addField(t)) @@ -171,8 +172,8 @@ object ParquetUtils { val schemaConverter = new ParquetToSparkSchemaConverter val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) - val primitiveTypeName = primitiveTypes.map(_.getPrimitiveTypeName) - primitiveTypeName.zipWithIndex.foreach { + val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName) + primitiveTypeNames.zipWithIndex.foreach { case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => val v = values(i).asInstanceOf[Boolean] converter.getConverter(i).asPrimitiveConverter.addBoolean(v) @@ -195,7 +196,7 @@ object ParquetUtils { val v = values(i).asInstanceOf[Binary] converter.getConverter(i).asPrimitiveConverter.addBinary(v) case (_, i) => - throw new SparkException("Unexpected parquet type name: " + primitiveTypeName(i)) + throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i)) } converter.currentRecord } @@ -211,6 +212,7 @@ object ParquetUtils { */ private[sql] def createAggColumnarBatchFromFooter( footer: ParquetMetadata, + filePath: String, dataSchema: StructType, partitionSchema: StructType, aggregation: Aggregation, @@ -220,6 +222,7 @@ object ParquetUtils { isCaseSensitive: Boolean): ColumnarBatch = { val row = createAggInternalRowFromFooter( footer, + filePath, dataSchema, partitionSchema, aggregation, @@ -246,6 +249,7 @@ object ParquetUtils { */ private[sql] def getPushedDownAggResult( footer: ParquetMetadata, + filePath: String, dataSchema: StructType, partitionSchema: StructType, aggregation: Aggregation, @@ -270,7 +274,7 @@ object ParquetUtils { val colName = max.column.fieldNames.head index = dataSchema.fieldNames.toList.indexOf(colName) schemaName = "max(" + colName + ")" - val currentMax = getCurrentBlockMaxOrMin(blockMetaData, index, true) + val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true) if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { value = currentMax } @@ -278,7 +282,7 @@ object ParquetUtils { val colName = min.column.fieldNames.head index = dataSchema.fieldNames.toList.indexOf(colName) schemaName = "min(" + colName + ")" - val currentMin = getCurrentBlockMaxOrMin(blockMetaData, index, false) + val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false) if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { value = currentMin } @@ -294,7 +298,7 @@ object ParquetUtils { if (!isPartitionCol) { index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head) // Count(*) includes the null values, but Count(colName) doesn't. - rowCount -= getNumNulls(blockMetaData, index) + rowCount -= getNumNulls(filePath, blockMetaData, index) } case _: CountStar => schemaName = "count(*)" @@ -324,25 +328,28 @@ object ParquetUtils { * @return the Max or Min value */ private def getCurrentBlockMaxOrMin( + filePath: String, columnChunkMetaData: util.List[ColumnChunkMetaData], i: Int, isMax: Boolean): Any = { val statistics = columnChunkMetaData.get(i).getStatistics if (!statistics.hasNonNullValue) { - throw new UnsupportedOperationException("No min/max found for Parquet file, Set SQLConf" + - s" ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") + throw new UnsupportedOperationException(s"No min/max found for Parquet file $filePath. " + + s"Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") } else { if (isMax) statistics.genericGetMax else statistics.genericGetMin } } private def getNumNulls( + filePath: String, columnChunkMetaData: util.List[ColumnChunkMetaData], i: Int): Long = { val statistics = columnChunkMetaData.get(i).getStatistics if (!statistics.isNumNullsSet) { - throw new UnsupportedOperationException("Number of nulls not set for Parquet file." + - s" Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") + throw new UnsupportedOperationException(s"Number of nulls not set for Parquet file" + + s" $filePath. Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute" + + s" again") } statistics.getNumNulls; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index a58a2192a50b5..111018b579ed2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -135,9 +135,9 @@ case class ParquetPartitionReaderFactory( private lazy val row: InternalRow = { val footer = getFooter(file) if (footer != null && footer.getBlocks.size > 0) { - ParquetUtils.createAggInternalRowFromFooter(footer, dataSchema, partitionSchema, - aggregation.get, readDataSchema, getDatetimeRebaseMode(footer.getFileMetaData), - isCaseSensitive) + ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema, + partitionSchema, aggregation.get, readDataSchema, + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) } else { null } @@ -175,11 +175,11 @@ case class ParquetPartitionReaderFactory( } else { new PartitionReader[ColumnarBatch] { private var hasNext = true - private var row: ColumnarBatch = { + private val row: ColumnarBatch = { val footer = getFooter(file) if (footer != null && footer.getBlocks.size > 0) { - ParquetUtils.createAggColumnarBatchFromFooter(footer, dataSchema, partitionSchema, - aggregation.get, readDataSchema, enableOffHeapColumnVector, + ParquetUtils.createAggColumnarBatchFromFooter(footer, file.filePath, dataSchema, + partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector, getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) } else { null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 4d7b266f7a8ba..fbddd062ea1a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -102,7 +102,8 @@ case class ParquetScanBuilder( val (column, aggType) = agg match { case max: Max => (max.column, "max") case min: Min => (min.column, "min") - case _ => throw new IllegalArgumentException("Unexpected type of AggregateFunc") + case _ => + throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}") } if (isPartitionCol(column)) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala index 0eb8371e97c31..c795bd9ff3389 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala @@ -445,8 +445,6 @@ abstract class ParquetAggregatePushDownSuite "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, 12345.678, ("2021-01-01").date))) - val testCountStar = sql("SELECT count(*) FROM test") - val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + " count(LongCol), count(FloatCol), count(DoubleCol)," + From 51ef0f2c919add3b75f483158a4e56b599dedc7e Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 10 Oct 2021 11:43:35 -0700 Subject: [PATCH 18/18] address comments --- .../scala/org/apache/spark/sql/types/StructType.scala | 2 +- .../sql/execution/datasources/parquet/ParquetUtils.scala | 1 + .../sql/execution/datasources/v2/FileScanBuilder.scala | 2 +- .../execution/datasources/v2/parquet/ParquetScan.scala | 5 ++++- .../datasources/v2/parquet/ParquetScanBuilder.scala | 8 +++----- 5 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index c9862cb629cff..50b197fb9aea3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru def names: Array[String] = fieldNames private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap override def equals(that: Any): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index ce96e3714e425..1093f9c5aa51b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -261,6 +261,7 @@ object ParquetUtils { val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType] val valuesBuilder = mutable.ArrayBuilder.make[Any] + assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down") aggregation.aggregateExpressions.foreach { agg => var value: Any = None var rowCount = 0L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 309f045201140..2dc4137d6f9a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -96,6 +96,6 @@ abstract class FileScanBuilder( private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet - private val partitionNameSet: Set[String] = + val partitionNameSet: Set[String] = partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index ba56f5eda232d..42dc287f73129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -49,8 +49,11 @@ case class ParquetScan( dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true - override def readSchema(): StructType = + override def readSchema(): StructType = { + // If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder` + // and no need to call super.readSchema() if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema() + } override def createReaderFactory(): PartitionReaderFactory = { val readDataSchemaAsJson = readDataSchema.json diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index fbddd062ea1a4..c579867623e1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} -import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy @@ -89,13 +89,11 @@ case class ParquetScanBuilder( override def pushAggregation(aggregation: Aggregation): Boolean = { def getStructFieldForCol(col: NamedReference): StructField = { - schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head)) + schema.nameToField(col.fieldNames.head) } def isPartitionCol(col: NamedReference) = { - (readPartitionSchema.fields.map(PartitioningUtils - .getColName(_, sparkSession.sessionState.conf.caseSensitiveAnalysis)) - .toSet.contains(col.fieldNames.head)) + partitionNameSet.contains(col.fieldNames.head) } def processMinOrMax(agg: AggregateFunc): Boolean = {