diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index c866dd834a52..0ad3862f6cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -67,4 +67,6 @@ object OrcOptions { "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + def getORCCompressionCodecName(name: String): String = shortOrcCompressionCodecNames(name) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index ef67ea7d17ce..f36a89a4c3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[parquet] class ParquetOptions( +class ParquetOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -82,4 +82,8 @@ object ParquetOptions { "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, "lzo" -> CompressionCodecName.LZO) + + def getParquetCompressionCodecName(name: String): String = { + shortParquetCompressionCodecNames(name).name() + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 5c515515b9b9..802ddafdbee4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -19,7 +19,16 @@ package org.apache.spark.sql.hive.execution import java.util.Locale +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.internal.SQLConf /** * Options for the Hive data source. Note that rule `DetermineHiveSerde` will extract Hive @@ -102,4 +111,17 @@ object HiveOptions { "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } + + def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)] = { + val tableProps = tableInfo.getProperties.asScala.toMap + tableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("parquetoutputformat") => + val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodecClassName + Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) + case formatName if formatName.endsWith("orcoutputformat") => + val compressionCodec = new OrcOptions(tableProps, sqlConf).compressionCodec + Option((COMPRESS.getAttribute, compressionCodec)) + case _ => None + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 9a6607f2f2c6..e484356906e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -55,18 +55,28 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { - val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean + val isCompressed = + fileSinkConf.getTableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("orcoutputformat") => + // For ORC,"mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact because it uses table properties to store compression information. + false + case _ => hadoopConf.get("hive.exec.compress.output", "false").toBoolean + } + if (isCompressed) { - // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", - // "mapreduce.output.fileoutputformat.compress.codec", and - // "mapreduce.output.fileoutputformat.compress.type" - // have no impact on ORC because it uses table properties to store compression information. hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) fileSinkConf.setCompressCodec(hadoopConf .get("mapreduce.output.fileoutputformat.compress.codec")) fileSinkConf.setCompressType(hadoopConf .get("mapreduce.output.fileoutputformat.compress.type")) + } else { + // Set compression by priority + HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) + .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } } val committer = FileCommitProtocol.instantiate( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala new file mode 100644 index 000000000000..d10a6f25c64f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetTest} +import org.apache.spark.sql.hive.orc.OrcFileOperator +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf + +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with BeforeAndAfterAll { + import spark.implicits._ + + override def beforeAll(): Unit = { + super.beforeAll() + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") + } + + override def afterAll(): Unit = { + try { + spark.catalog.dropTempView("table_source") + } finally { + super.afterAll() + } + } + + private val maxRecordNum = 50 + + private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { + case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key + case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + } + + private def getSparkCompressionConfName(format: String): String = format.toLowerCase match { + case "parquet" => SQLConf.PARQUET_COMPRESSION.key + case "orc" => SQLConf.ORC_COMPRESSION.key + } + + private def getHiveCompressPropName(format: String): String = format.toLowerCase match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } + + private def normalizeCodecName(format: String, name: String): String = { + format.toLowerCase match { + case "parquet" => ParquetOptions.getParquetCompressionCodecName(name) + case "orc" => OrcOptions.getORCCompressionCodecName(name) + } + } + + private def getTableCompressionCodec(path: String, format: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = format.toLowerCase match { + case "parquet" => for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + case "orc" => new File(path).listFiles().filter { file => + file.isFile && !file.getName.endsWith(".crc") && file.getName != "_SUCCESS" + }.map { orcFile => + OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString + }.toSeq + } + codecs.distinct + } + + private def createTable( + rootDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String]): Unit = { + val tblProperties = compressionCodec match { + case Some(prop) => s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')" + case _ => "" + } + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" else "" + sql( + s""" + |CREATE TABLE $tableName(a int) + |$partitionCreate + |STORED AS $format + |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' + |$tblProperties + """.stripMargin) + } + + private def writeDataToTable( + tableName: String, + partitionValue: Option[String]): Unit = { + val partitionInsert = partitionValue.map(p => s"partition (p='$p')").mkString + sql( + s""" + |INSERT INTO TABLE $tableName + |$partitionInsert + |SELECT * FROM table_source + """.stripMargin) + } + + private def writeDateToTableUsingCTAS( + rootDir: File, + tableName: String, + partitionValue: Option[String], + format: String, + compressionCodec: Option[String]): Unit = { + val partitionCreate = partitionValue.map(p => s"PARTITIONED BY (p)").mkString + val compressionOption = compressionCodec.map { codec => + s",'${getHiveCompressPropName(format)}'='$codec'" + }.mkString + val partitionSelect = partitionValue.map(p => s",'$p' AS p").mkString + sql( + s""" + |CREATE TABLE $tableName + |USING $format + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName' $compressionOption) + |$partitionCreate + |AS SELECT * $partitionSelect FROM table_source + """.stripMargin) + } + + private def getPreparedTablePath( + tmpDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String], + usingCTAS: Boolean): String = { + val partitionValue = if (isPartitioned) Some("test") else None + if (usingCTAS) { + writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) + } else { + createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) + writeDataToTable(tableName, partitionValue) + } + getTablePartitionPath(tmpDir, tableName, partitionValue) + } + + private def getTableSize(path: String): Long = { + val dir = new File(path) + val files = dir.listFiles().filter(_.getName.startsWith("part-")) + files.map(_.length()).sum + } + + private def getTablePartitionPath( + dir: File, + tableName: String, + partitionValue: Option[String]) = { + val partitionPath = partitionValue.map(p => s"p=$p").mkString + s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath" + } + + private def getUncompressedDataSizeByFormat( + format: String, isPartitioned: Boolean, usingCTAS: Boolean): Long = { + var totalSize = 0L + val tableName = s"tbl_$format" + val codecName = normalizeCodecName(format, "uncompressed") + withSQLConf(getSparkCompressionConfName(format) -> codecName) { + withTempDir { tmpDir => + withTable(tableName) { + val compressionCodec = Option(codecName) + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + totalSize = getTableSize(path) + } + } + } + assert(totalSize > 0L) + totalSize + } + + private def checkCompressionCodecForTable( + format: String, + isPartitioned: Boolean, + compressionCodec: Option[String], + usingCTAS: Boolean) + (assertion: (String, Long) => Unit): Unit = { + val tableName = + if (usingCTAS) s"tbl_$format$isPartitioned" else s"tbl_$format${isPartitioned}_CAST" + withTempDir { tmpDir => + withTable(tableName) { + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) + val relCompressionCodecs = getTableCompressionCodec(path, format) + assert(relCompressionCodecs.length == 1) + val tableSize = getTableSize(path) + assertion(relCompressionCodecs.head, tableSize) + } + } + } + + private def checkTableCompressionCodecForCodecs( + format: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + compressionCodecs: List[String], + tableCompressionCodecs: List[String]) + (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) { + tableCompressionCodecs.foreach { tableCompression => + compressionCodecs.foreach { sessionCompressionCodec => + withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) { + // 'tableCompression = null' means no table-level compression + val compression = Option(tableCompression) + checkCompressionCodecForTable(format, isPartitioned, compression, usingCTAS) { + case (realCompressionCodec, tableSize) => + assertionCompressionCodec( + compression, sessionCompressionCodec, realCompressionCodec, tableSize) + } + } + } + } + } + } + + // When the amount of data is small, compressed data size may be larger than uncompressed one, + // so we just check the difference when compressionCodec is not NONE or UNCOMPRESSED. + private def checkTableSize( + format: String, + compressionCodec: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + usingCTAS: Boolean, + tableSize: Long): Boolean = { + val uncompressedSize = getUncompressedDataSizeByFormat(format, isPartitioned, usingCTAS) + compressionCodec match { + case "UNCOMPRESSED" if format == "parquet" => tableSize == uncompressedSize + case "NONE" if format == "orc" => tableSize == uncompressedSize + case _ => tableSize != uncompressedSize + } + } + + def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // For non-partitioned table and when convertMetastore is true, Expect session-level + // take effect, and in other cases expect table-level take effect + // TODO: It should always be table-level taking effect when the bug(SPARK-22926) + // is fixed + val expectCodec = + if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + assert(expectCodec == realCodec) + assert(checkTableSize( + format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCodec, sessionCodec, realCodec, tableSize) => + // Always expect session-level take effect + assert(sessionCodec == realCodec) + assert(checkTableSize( + format, sessionCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) + } + } + } + } + } + + test("both table-level and session-level compression are set") { + checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + test("table-level compression is not set but session-level compressions is set ") { + checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + def checkTableWriteWithCompressionCodecs(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + withTempDir { tmpDir => + val tableName = s"tbl_$format$isPartitioned" + createTable(tmpDir, tableName, isPartitioned, format, None) + withTable(tableName) { + compressCodecs.foreach { compressionCodec => + val partitionValue = if (isPartitioned) Some(compressionCodec) else None + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString, + getSparkCompressionConfName(format) -> compressionCodec + ) { writeDataToTable(tableName, partitionValue) } + } + val tablePath = getTablePartitionPath(tmpDir, tableName, None) + val realCompressionCodecs = + if (isPartitioned) compressCodecs.flatMap { codec => + getTableCompressionCodec(s"$tablePath/p=$codec", format) + } else { + getTableCompressionCodec(tablePath, format) + } + + assert(realCompressionCodecs.distinct.sorted == compressCodecs.sorted) + val recordsNum = sql(s"SELECT * from $tableName").count() + assert(recordsNum == maxRecordNum * compressCodecs.length) + } + } + } + } + } + + test("test table containing mixed compression codec") { + checkTableWriteWithCompressionCodecs("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkTableWriteWithCompressionCodecs("orc", List("NONE", "SNAPPY", "ZLIB")) + } +}