22 changes: 13 additions & 9 deletions core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -25,7 +25,7 @@ import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}

/**
* :: DeveloperApi ::
@@ -53,10 +53,14 @@ private[spark] object CompressionCodec {
|| codec.isInstanceOf[LZ4CompressionCodec])
}

private val shortCompressionCodecNames = Map(
"lz4" -> classOf[LZ4CompressionCodec].getName,
"lzf" -> classOf[LZFCompressionCodec].getName,
"snappy" -> classOf[SnappyCompressionCodec].getName)
/** Maps the short versions of compression codec names to fully-qualified class names. */
private val shortCompressionCodecNameMapper = new ShortCompressionCodecNameMapper {
override def lz4: Option[String] = Some(classOf[LZ4CompressionCodec].getName)
override def lzf: Option[String] = Some(classOf[LZFCompressionCodec].getName)
override def snappy: Option[String] = Some(classOf[SnappyCompressionCodec].getName)
}

private val shortCompressionCodecMap = shortCompressionCodecNameMapper.getAsMap

def getCodecName(conf: SparkConf): String = {
conf.get(configKey, DEFAULT_COMPRESSION_CODEC)
@@ -67,7 +71,7 @@
}

def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
val codecClass = shortCompressionCodecNameMapper.get(codecName).getOrElse(codecName)
val codec = try {
val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
@@ -84,18 +88,18 @@
* If it is already a short name, just return it.
*/
def getShortName(codecName: String): String = {
if (shortCompressionCodecNames.contains(codecName)) {
if (shortCompressionCodecMap.contains(codecName)) {
codecName
} else {
shortCompressionCodecNames
shortCompressionCodecMap
.collectFirst { case (k, v) if v == codecName => k }
.getOrElse { throw new IllegalArgumentException(s"No short name for codec $codecName.") }
}
}

val FALLBACK_COMPRESSION_CODEC = "snappy"
val DEFAULT_COMPRESSION_CODEC = "lz4"
val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
val ALL_COMPRESSION_CODECS = shortCompressionCodecMap.values.toSeq
}

/**
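As a quick reference for the refactor above, here is a minimal sketch of how createCodec now resolves a codec name through a ShortCompressionCodecNameMapper (the class added in the Utils.scala hunk below); the snippet is illustrative only and assumes this patch is applied.

import org.apache.spark.util.ShortCompressionCodecNameMapper

// Sketch of the resolution behaviour: lookups are case-insensitive, and unknown names
// fall through so the raw value can still be treated as a fully-qualified class name.
val mapper = new ShortCompressionCodecNameMapper {
  override def lz4: Option[String] = Some("org.apache.spark.io.LZ4CompressionCodec")
  override def snappy: Option[String] = Some("org.apache.spark.io.SnappyCompressionCodec")
}

mapper.get("LZ4").getOrElse("LZ4")      // "org.apache.spark.io.LZ4CompressionCodec"
mapper.get("com.example.MyCodec")       // None, so createCodec keeps the raw name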
45 changes: 45 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -60,6 +60,51 @@ private[spark] object CallSite {
val empty = CallSite("", "")
}

/** A utility class to map short compression codec names to qualified ones. */
private[spark] class ShortCompressionCodecNameMapper {

def get(codecName: String): Option[String] = codecName.toLowerCase match {
case "none" => none
case "uncompressed" => uncompressed
case "bzip2" => bzip2
case "deflate" => deflate
case "gzip" => gzip
case "lzo" => lzo
case "lz4" => lz4
case "lzf" => lzf
case "snappy" => snappy
case _ => None
}

def getAsMap: Map[String, String] = {
Seq(
("none", none),
("uncompressed", uncompressed),
("bzip2", bzip2),
("deflate", deflate),
("gzip", gzip),
("lzo", lzo),
("lz4", lz4),
("lzf", lzf),
("snappy", snappy)
).flatMap { case (shortCodecName, codecName) =>
if (codecName.isDefined) Some(shortCodecName, codecName.get) else None
}.toMap
}

// To support short codec names, derived classes need to override the methods below that return
// corresponding qualified codec names.
def none: Option[String] = None
def uncompressed: Option[String] = None
def bzip2: Option[String] = None
def deflate: Option[String] = None
def gzip: Option[String] = None
def lzo: Option[String] = None
def lz4: Option[String] = None
def lzf: Option[String] = None
def snappy: Option[String] = None
}

/**
* Various utility methods used by Spark.
*/
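A small usage sketch of the class above, with a hypothetical subclass: only the overridden codecs appear in getAsMap, and every other short name resolves to None.

import org.apache.spark.util.ShortCompressionCodecNameMapper

// Hypothetical subclass for illustration: only gzip is overridden.
val gzipOnly = new ShortCompressionCodecNameMapper {
  override def gzip: Option[String] = Some("org.apache.hadoop.io.compress.GzipCodec")
}

assert(gzipOnly.getAsMap == Map("gzip" -> "org.apache.hadoop.io.compress.GzipCodec"))
assert(gzipOnly.get("GZIP") == Some("org.apache.hadoop.io.compress.GzipCodec"))
assert(gzipOnly.get("lzo").isEmpty)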
30 changes: 15 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -1,19 +1,19 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import java.io.IOException

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat

@@ -58,7 +59,8 @@ import org.apache.spark.util.Utils
private[sql] case class InsertIntoHadoopFsRelation(
@transient relation: HadoopFsRelation,
@transient query: LogicalPlan,
mode: SaveMode)
mode: SaveMode,
compressionCodec: Option[Class[_ <: CompressionCodec]] = None)
extends RunnableCommand {

override def run(sqlContext: SQLContext): Seq[Row] = {
@@ -126,7 +128,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
""".stripMargin)

val writerContainer = if (partitionColumns.isEmpty && relation.maybeBucketSpec.isEmpty) {
new DefaultWriterContainer(relation, job, isAppend)
new DefaultWriterContainer(relation, job, isAppend, compressionCodec)
} else {
val output = df.queryExecution.executedPlan.output
val (partitionOutput, dataOutput) =
@@ -140,7 +142,8 @@
output,
PartitioningUtils.DEFAULT_PARTITION_NAME,
sqlContext.conf.getConf(SQLConf.PARTITION_MAX_FILES),
isAppend)
isAppend,
compressionCodec)
}

// This call shouldn't be put into the `try` block below because it only initializes and
@@ -24,6 +24,7 @@ import scala.language.{existentials, implicitConversions}
import scala.util.{Failure, Success, Try}

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress._
import org.apache.hadoop.util.StringUtils

import org.apache.spark.Logging
@@ -32,7 +33,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{CalendarIntervalType, StructType}
import org.apache.spark.util.Utils
import org.apache.spark.util.{ShortCompressionCodecNameMapper, Utils}

case class ResolvedDataSource(provider: Class[_], relation: BaseRelation)

@@ -49,6 +50,14 @@ object ResolvedDataSource extends Logging {
"org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName
)

/** Maps the short versions of compression codec names to fully-qualified class names. */
private val hadoopShortCodecNameMapper = new ShortCompressionCodecNameMapper {
override def bzip2: Option[String] = Some(classOf[BZip2Codec].getCanonicalName)
override def deflate: Option[String] = Some(classOf[DeflateCodec].getCanonicalName)
override def gzip: Option[String] = Some(classOf[GzipCodec].getCanonicalName)
override def snappy: Option[String] = Some(classOf[SnappyCodec].getCanonicalName)
}

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider0: String): Class[_] = {
val provider = backwardCompatibilityMap.getOrElse(provider0, provider0)
@@ -286,14 +295,25 @@
bucketSpec,
caseInsensitiveOptions)

val compressionCodec = options
.get("compressionCodec")
.map { codecName =>
val codecFactory = new CompressionCodecFactory(
sqlContext.sparkContext.hadoopConfiguration)
val resolvedCodecName = hadoopShortCodecNameMapper.get(codecName).getOrElse(codecName)
Option(codecFactory.getCodecClassByName(resolvedCodecName))
}
.getOrElse(None)

// For partitioned relation r, r.schema's column ordering can be different from the column
// ordering of data.logicalPlan (partition columns are all moved after data column). This
// will be adjusted within InsertIntoHadoopFsRelation.
sqlContext.executePlan(
InsertIntoHadoopFsRelation(
r,
data.logicalPlan,
mode)).toRdd
mode,
compressionCodec)).toRdd
r
case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
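For clarity, a standalone sketch of the codec-resolution step added above, with a bare Hadoop Configuration standing in for sqlContext.sparkContext.hadoopConfiguration; GzipCodec is used purely as an example.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.{CompressionCodec, CompressionCodecFactory, GzipCodec}

// A short name such as "gzip" is first expanded to a fully-qualified class name by
// hadoopShortCodecNameMapper, then looked up through Hadoop's codec factory.
val factory = new CompressionCodecFactory(new Configuration())
val resolvedName = classOf[GzipCodec].getCanonicalName
val codecClass: Option[Class[_ <: CompressionCodec]] =
  Option(factory.getCodecClassByName(resolvedName))   // None if the name cannot be resolved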
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import java.util.{Date, UUID}

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -39,7 +40,8 @@ import org.apache.spark.util.SerializableConfiguration
private[sql] abstract class BaseWriterContainer(
@transient val relation: HadoopFsRelation,
@transient private val job: Job,
isAppend: Boolean)
isAppend: Boolean,
compressionCodec: Option[Class[_ <: CompressionCodec]] = None)
extends Logging with Serializable {

protected val dataSchema = relation.dataSchema
@@ -207,6 +209,11 @@ private[sql] abstract class BaseWriterContainer(
serializableConf.value.set("mapred.task.id", taskAttemptId.toString)
serializableConf.value.setBoolean("mapred.task.is.map", true)
serializableConf.value.setInt("mapred.task.partition", 0)
compressionCodec.map { codecClass =>
serializableConf.value.set("mapred.output.compress", "true")
serializableConf.value.set("mapred.output.compression.codec", codecClass.getCanonicalName)
serializableConf.value.set("mapred.output.compression.type", "BLOCK")
}
}

def commitTask(): Unit = {
@@ -239,8 +246,9 @@ private[sql] abstract class BaseWriterContainer(
private[sql] class DefaultWriterContainer(
relation: HadoopFsRelation,
job: Job,
isAppend: Boolean)
extends BaseWriterContainer(relation, job, isAppend) {
isAppend: Boolean,
compressionCodec: Option[Class[_ <: CompressionCodec]])
extends BaseWriterContainer(relation, job, isAppend, compressionCodec) {

def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
executorSideSetup(taskContext)
@@ -308,8 +316,9 @@ private[sql] class DynamicPartitionWriterContainer(
inputSchema: Seq[Attribute],
defaultPartitionName: String,
maxOpenFiles: Int,
isAppend: Boolean)
extends BaseWriterContainer(relation, job, isAppend) {
isAppend: Boolean,
compressionCodec: Option[Class[_ <: CompressionCodec]])
extends BaseWriterContainer(relation, job, isAppend, compressionCodec) {

private val bucketSpec = relation.maybeBucketSpec

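The three properties set in the BaseWriterContainer hunk are the classic "mapred.*" Hadoop output-compression keys. A standalone sketch of the equivalent configuration, with GzipCodec as an illustrative choice:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.GzipCodec

// Equivalent of what the container sets when a compressionCodec class is supplied.
val hadoopConf = new Configuration()
hadoopConf.set("mapred.output.compress", "true")
hadoopConf.set("mapred.output.compression.codec", classOf[GzipCodec].getCanonicalName)
hadoopConf.set("mapred.output.compression.type", "BLOCK")   // block-level rather than per-record compression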
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser
import org.apache.spark.sql.execution.datasources.{PartitionSpec, _}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.{SerializableConfiguration, ShortCompressionCodecNameMapper, Utils}

private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister {

@@ -283,10 +283,8 @@ private[sql] class ParquetRelation(
conf.set(
ParquetOutputFormat.COMPRESSION,
ParquetRelation
.shortParquetCompressionCodecNames
.getOrElse(
sqlContext.conf.parquetCompressionCodec.toUpperCase,
CompressionCodecName.UNCOMPRESSED).name())
.parquetShortCodecNameMapper.get(sqlContext.conf.parquetCompressionCodec)
.getOrElse(CompressionCodecName.UNCOMPRESSED.name()))

new BucketedOutputWriterFactory {
override def newInstance(
@@ -902,11 +900,12 @@ private[sql] object ParquetRelation extends Logging {
}
}

// The parquet compression short names
val shortParquetCompressionCodecNames = Map(
"NONE" -> CompressionCodecName.UNCOMPRESSED,
"UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED,
"SNAPPY" -> CompressionCodecName.SNAPPY,
"GZIP" -> CompressionCodecName.GZIP,
"LZO" -> CompressionCodecName.LZO)
/** Maps the short versions of compression codec names to qualified compression names. */
val parquetShortCodecNameMapper = new ShortCompressionCodecNameMapper {
override def none: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
override def uncompressed: Option[String] = Some(CompressionCodecName.UNCOMPRESSED.name())
override def gzip: Option[String] = Some(CompressionCodecName.GZIP.name())
override def lzo: Option[String] = Some(CompressionCodecName.LZO.name())
override def snappy: Option[String] = Some(CompressionCodecName.SNAPPY.name())
}
}
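As a quick check of the Parquet mapping above, a sketch that mirrors parquetShortCodecNameMapper (assumes this patch is applied; CompressionCodecName comes from parquet-hadoop):

import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.spark.util.ShortCompressionCodecNameMapper

// Case-insensitive short names map to Parquet enum names; anything unrecognized
// falls back to UNCOMPRESSED, matching the conf.set call above.
val parquetMapper = new ShortCompressionCodecNameMapper {
  override def snappy: Option[String] = Some(CompressionCodecName.SNAPPY.name())
  override def gzip: Option[String] = Some(CompressionCodecName.GZIP.name())
}

parquetMapper.get("Snappy").getOrElse(CompressionCodecName.UNCOMPRESSED.name())   // "SNAPPY"
parquetMapper.get("lz4").getOrElse(CompressionCodecName.UNCOMPRESSED.name())      // "UNCOMPRESSED"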
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.text

import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.Utils
@@ -57,6 +57,15 @@ class TextSuite extends QueryTest with SharedSQLContext {
}
}

test("compression") {
Seq("bzip2", "deflate", "gzip").map { codecName =>
val tempDirPath = Utils.createTempDir().getAbsolutePath
val df = sqlContext.read.text(testFile)
df.write.option("compressionCodec", codecName).mode(SaveMode.Overwrite).text(tempDirPath)
verifyFrame(sqlContext.read.text(tempDirPath))
}
}

private def testFile: String = {
Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
}
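End to end, the new write option is exercised as in the test above; a sketch assuming a live sqlContext, with placeholder paths:

import org.apache.spark.sql.SaveMode

// Write gzip-compressed text via the new "compressionCodec" option, then read it back;
// the short name is resolved through the Hadoop codec factory as shown earlier.
val df = sqlContext.read.text("/path/to/input.txt")
df.write
  .option("compressionCodec", "gzip")
  .mode(SaveMode.Overwrite)
  .text("/path/to/output")
val roundTripped = sqlContext.read.text("/path/to/output")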
@@ -306,7 +306,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {

val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt")
df.queryExecution.sparkPlan match {
case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK
case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _, _)) => // OK
case o => fail("test_insert_parquet should be converted to a " +
s"${classOf[ParquetRelation].getCanonicalName} and " +
s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " +
@@ -336,7 +336,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {

val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array")
df.queryExecution.sparkPlan match {
case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK
case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _, _)) => // OK
case o => fail("test_insert_parquet should be converted to a " +
s"${classOf[ParquetRelation].getCanonicalName} and " +
s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." +