diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index ae014becef75..97fdc232be8f 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -25,7 +25,7 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}
 
 import org.apache.spark.SparkConf
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}
 
 /**
  * :: DeveloperApi ::
@@ -53,10 +53,14 @@ private[spark] object CompressionCodec {
       || codec.isInstanceOf[LZ4CompressionCodec])
   }
 
-  private val shortCompressionCodecNames = Map(
-    "lz4" -> classOf[LZ4CompressionCodec].getName,
-    "lzf" -> classOf[LZFCompressionCodec].getName,
-    "snappy" -> classOf[SnappyCompressionCodec].getName)
+  /** Maps the short versions of compression codec names to fully-qualified class names. */
+  private val shortCompressionCodecNameMapper = new ShortCompressionCodecNameMapper {
+    override def lz4: Option[String] = Some(classOf[LZ4CompressionCodec].getName)
+    override def lzf: Option[String] = Some(classOf[LZFCompressionCodec].getName)
+    override def snappy: Option[String] = Some(classOf[SnappyCompressionCodec].getName)
+  }
+
+  private val shortCompressionCodecMap = shortCompressionCodecNameMapper.getAsMap
 
   def getCodecName(conf: SparkConf): String = {
     conf.get(configKey, DEFAULT_COMPRESSION_CODEC)
@@ -67,7 +71,7 @@ private[spark] object CompressionCodec {
   }
 
   def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
-    val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
+    val codecClass = shortCompressionCodecNameMapper.get(codecName).getOrElse(codecName)
     val codec = try {
       val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
       Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
@@ -84,10 +88,10 @@ private[spark] object CompressionCodec {
    * If it is already a short name, just return it.
    */
   def getShortName(codecName: String): String = {
-    if (shortCompressionCodecNames.contains(codecName)) {
+    if (shortCompressionCodecMap.contains(codecName)) {
       codecName
     } else {
-      shortCompressionCodecNames
+      shortCompressionCodecMap
         .collectFirst { case (k, v) if v == codecName => k }
         .getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") }
     }
@@ -95,7 +99,7 @@ private[spark] object CompressionCodec {
 
   val FALLBACK_COMPRESSION_CODEC = "snappy"
   val DEFAULT_COMPRESSION_CODEC = "lz4"
-  val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
+  val ALL_COMPRESSION_CODECS = shortCompressionCodecMap.values.toSeq
 }
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e0c9bf02a1a2..967593e737af 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -60,6 +60,51 @@ private[spark] object CallSite {
   val empty = CallSite("", "")
 }
 
+/** A utility class to map short compression codec names to qualified ones. */
+private[spark] class ShortCompressionCodecNameMapper {
+
+  def get(codecName: String): Option[String] = codecName.toLowerCase match {
+    case "none" => none
+    case "uncompressed" => uncompressed
+    case "bzip2" => bzip2
+    case "deflate" => deflate
+    case "gzip" => gzip
+    case "lzo" => lzo
+    case "lz4" => lz4
+    case "lzf" => lzf
+    case "snappy" => snappy
+    case _ => None
+  }
+
+  def getAsMap: Map[String, String] = {
+    Seq(
+      ("none", none),
+      ("uncompressed", uncompressed),
+      ("bzip2", bzip2),
+      ("deflate", deflate),
+      ("gzip", gzip),
+      ("lzo", lzo),
+      ("lz4", lz4),
+      ("lzf", lzf),
+      ("snappy", snappy)
+    ).flatMap { case (shortCodecName, codecName) =>
+      if (codecName.isDefined) Some(shortCodecName, codecName.get) else None
+    }.toMap
+  }
+
+  // To support short codec names, derived classes need to override the methods below that return
+  // corresponding qualified codec names.
+  def none: Option[String] = None
+  def uncompressed: Option[String] = None
+  def bzip2: Option[String] = None
+  def deflate: Option[String] = None
+  def gzip: Option[String] = None
+  def lzo: Option[String] = None
+  def lz4: Option[String] = None
+  def lzf: Option[String] = None
+  def snappy: Option[String] = None
+}
+
 /**
  * Various utility methods used by Spark.
  */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index d6bdd3d82556..c5839fc3b648 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -1,19 +1,19 @@
 /*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements.  See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License.  You may obtain a copy of the License at
-*
-*    http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 package org.apache.spark.sql
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
index 2d3e1714d2b7..f54e2c211454 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.io.IOException
 
 import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
 
@@ -58,7 +59,8 @@ import org.apache.spark.util.Utils
 private[sql] case class InsertIntoHadoopFsRelation(
     @transient relation: HadoopFsRelation,
     @transient query: LogicalPlan,
-    mode: SaveMode)
+    mode: SaveMode,
+    compressionCodec: Option[Class[_ <: CompressionCodec]] = None)
   extends RunnableCommand {
 
   override def run(sqlContext: SQLContext): Seq[Row] = {
@@ -126,7 +128,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
           """.stripMargin)
 
       val writerContainer = if (partitionColumns.isEmpty && relation.maybeBucketSpec.isEmpty) {
-        new DefaultWriterContainer(relation, job, isAppend)
+        new DefaultWriterContainer(relation, job, isAppend, compressionCodec)
       } else {
         val output = df.queryExecution.executedPlan.output
         val (partitionOutput, dataOutput) =
@@ -140,7 +142,8 @@ private[sql] case class InsertIntoHadoopFsRelation(
           output,
           PartitioningUtils.DEFAULT_PARTITION_NAME,
           sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES),
-          isAppend)
+          isAppend,
+          compressionCodec)
       }
 
       // This call shouldn't be put into the `try` block below because it only initializes and
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index eec9070beed6..d882d19ea758 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -24,6 +24,7 @@ import scala.language.{existentials, implicitConversions}
 import scala.util.{Failure, Success, Try}
 
 import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress._
 import org.apache.hadoop.util.StringUtils
 
 import org.apache.spark.Logging
@@ -32,7 +33,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
 import org.apache.spark.sql.execution.streaming.{Sink, Source}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}
 
 case class ResolvedDataSource(provider: Class[_], relation: BaseRelation)
 
@@ -49,6 +50,14 @@ object ResolvedDataSource extends Logging {
     "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName
   )
 
+  /** Maps the short versions of compression codec names to fully-qualified class names. */
+  private val hadoopShortCodecNameMapper = new ShortCompressionCodecNameMapper {
+    override def bzip2: Option[String] = Some(classOf[BZip2Codec].getCanonicalName)
+    override def deflate: Option[String] = Some(classOf[DeflateCodec].getCanonicalName)
+    override def gzip: Option[String] = Some(classOf[GzipCodec].getCanonicalName)
+    override def snappy: Option[String] = Some(classOf[SnappyCodec].getCanonicalName)
+  }
+
   /** Given a provider name, look up the data source class definition. */
   def lookupDataSource(provider0: String): Class[_] = {
     val provider = backwardCompatibilityMap.getOrElse(provider0, provider0)
@@ -286,6 +295,16 @@ object ResolvedDataSource extends Logging {
         bucketSpec,
         caseInsensitiveOptions)
 
+      val compressionCodec = options
+        .get("compressionCodec")
+        .map { codecName =>
+          val codecFactory = new CompressionCodecFactory(
+            sqlContext.sparkContext.hadoopConfiguration)
+          val resolvedCodecName = hadoopShortCodecNameMapper.get(codecName).getOrElse(codecName)
+          Option(codecFactory.getCodecClassByName(resolvedCodecName))
+        }
+        .getOrElse(None)
+
       // For partitioned relation r, r.schema's column ordering can be different from the column
       // ordering of data.logicalPlan (partition columns are all moved after data column). This
       // will be adjusted within InsertIntoHadoopFsRelation.
@@ -293,7 +312,8 @@ object ResolvedDataSource extends Logging {
           InsertIntoHadoopFsRelation(
             r,
             data.logicalPlan,
-            mode)).toRdd
+            mode,
+            compressionCodec)).toRdd
         r
       case _ =>
         sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 7e5c8f2f48d6..a8aca95e0bc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.{Date, UUID}
 
 import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter}
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -39,7 +40,8 @@ import org.apache.spark.util.SerializableConfiguration
 private[sql] abstract class BaseWriterContainer(
     @transient val relation: HadoopFsRelation,
     @transient private val job: Job,
-    isAppend: Boolean)
+    isAppend: Boolean,
+    compressionCodec: Option[Class[_ <: CompressionCodec]] = None)
   extends Logging with Serializable {
 
   protected val dataSchema = relation.dataSchema
@@ -207,6 +209,11 @@ private[sql] abstract class BaseWriterContainer(
     serializableConf.value.set("mapred.task.id", taskAttemptId.toString)
     serializableConf.value.setBoolean("mapred.task.is.map", true)
     serializableConf.value.setInt("mapred.task.partition", 0)
+    compressionCodec.map { codecClass =>
+      serializableConf.value.set("mapred.output.compress", "true")
+      serializableConf.value.set("mapred.output.compression.codec", codecClass.getCanonicalName)
+      serializableConf.value.set("mapred.output.compression.type", "BLOCK")
+    }
   }
 
   def commitTask(): Unit = {
@@ -239,8 +246,9 @@ private[sql] abstract class BaseWriterContainer(
 private[sql] class DefaultWriterContainer(
     relation: HadoopFsRelation,
     job: Job,
-    isAppend: Boolean)
-  extends BaseWriterContainer(relation, job, isAppend) {
+    isAppend: Boolean,
+    compressionCodec: Option[Class[_ <: CompressionCodec]])
+  extends BaseWriterContainer(relation, job, isAppend, compressionCodec) {
 
   def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
     executorSideSetup(taskContext)
@@ -308,8 +316,9 @@ private[sql] class DynamicPartitionWriterContainer(
     inputSchema: Seq[Attribute],
     defaultPartitionName: String,
     maxOpenFiles: Int,
-    isAppend: Boolean)
-  extends BaseWriterContainer(relation, job, isAppend) {
+    isAppend: Boolean,
+    compressionCodec: Option[Class[_ <: CompressionCodec]])
+  extends BaseWriterContainer(relation, job, isAppend, compressionCodec) {
 
   private val bucketSpec = relation.maybeBucketSpec
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 1e686d41f41d..a4ff92fcf1a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser
 import org.apache.spark.sql.execution.datasources.{PartitionSpec, _}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.{SerializableConfiguration, ShortCompressionCodecNameMapper, Utils}
 
 private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {
 
@@ -283,10 +283,8 @@ private[sql] class ParquetRelation(
     conf.set(
       ParquetOutputFormat.COMPRESSION,
       ParquetRelation
-        .shortParquetCompressionCodecNames
-        .getOrElse(
-          sqlContext.conf.parquetCompressionCodec.toUpperCase,
-          CompressionCodecName.UNCOMPRESSED).name())
+        .parquetShortCodecNameMapper.get(sqlContext.conf.parquetCompressionCodec)
+        .getOrElse(CompressionCodecName.UNCOMPRESSED.name()))
 
     new BucketedOutputWriterFactory {
       override def newInstance(
@@ -902,11 +900,12 @@ private[sql] object ParquetRelation extends Logging {
       }
     }
 
-  // The parquet compression short names
-  val shortParquetCompressionCodecNames = Map(
-    "NONE" -> CompressionCodecName.UNCOMPRESSED,
-    "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED,
-    "SNAPPY" -> CompressionCodecName.SNAPPY,
-    "GZIP" -> CompressionCodecName.GZIP,
-    "LZO" -> CompressionCodecName.LZO)
+  /** Maps the short versions of compression codec names to qualified compression names. */
+  val parquetShortCodecNameMapper = new ShortCompressionCodecNameMapper {
+    override def none: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
+    override def uncompressed: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
+    override def gzip: Option[String] = Some(CompressionCodecName.GZIP.name())
+    override def lzo: Option[String] = Some(CompressionCodecName.LZO.name())
+    override def snappy: Option[String] = Some(CompressionCodecName.SNAPPY.name())
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index f95272530d58..67122ca68130 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.text
 
-import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.Utils
@@ -57,6 +57,15 @@ class TextSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("compression") {
+    Seq("bzip2", "deflate", "gzip").map { codecName =>
+      val tempDirPath = Utils.createTempDir().getAbsolutePath
+      val df = sqlContext.read.text(testFile)
+      df.write.option("compressionCodec", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
+      verifyFrame(sqlContext.read.text(tempDirPath))
+    }
+  }
+
   private def testFile: String = {
     Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index a6ca7d0386b2..e6204c37193d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -306,7 +306,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
 
       val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt")
       df.queryExecution.sparkPlan match {
-        case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK
+        case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _, _)) => // OK
         case o => fail("test_insert_parquet should be converted to a " +
           s"${classOf[ParquetRelation].getCanonicalName} and " +
           s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " +
@@ -336,7 +336,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
 
      val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array")
       df.queryExecution.sparkPlan match {
-        case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK
+        case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _, _)) => // OK
         case o => fail("test_insert_parquet should be converted to a " +
           s"${classOf[ParquetRelation].getCanonicalName} and " +
           s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." +