diff --git a/connector/avro/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java b/connector/avro/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java index 5cfcfffd07a28..c927991425cd5 100644 --- a/connector/avro/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java +++ b/connector/avro/src/main/java/org/apache/spark/sql/avro/AvroCompressionCodec.java @@ -22,29 +22,43 @@ import java.util.Map; import java.util.stream.Collectors; -import org.apache.avro.file.DataFileConstants; +import org.apache.avro.file.*; /** * A mapper class from Spark supported avro compression codecs to avro compression codecs. */ public enum AvroCompressionCodec { - UNCOMPRESSED(DataFileConstants.NULL_CODEC), - DEFLATE(DataFileConstants.DEFLATE_CODEC), - SNAPPY(DataFileConstants.SNAPPY_CODEC), - BZIP2(DataFileConstants.BZIP2_CODEC), - XZ(DataFileConstants.XZ_CODEC), - ZSTANDARD(DataFileConstants.ZSTANDARD_CODEC); + UNCOMPRESSED(DataFileConstants.NULL_CODEC, false, -1), + DEFLATE(DataFileConstants.DEFLATE_CODEC, true, CodecFactory.DEFAULT_DEFLATE_LEVEL), + SNAPPY(DataFileConstants.SNAPPY_CODEC, false, -1), + BZIP2(DataFileConstants.BZIP2_CODEC, false, -1), + XZ(DataFileConstants.XZ_CODEC, true, CodecFactory.DEFAULT_XZ_LEVEL), + ZSTANDARD(DataFileConstants.ZSTANDARD_CODEC, true, CodecFactory.DEFAULT_ZSTANDARD_LEVEL); private final String codecName; + private final boolean supportCompressionLevel; + private final int defaultCompressionLevel; - AvroCompressionCodec(String codecName) { + AvroCompressionCodec( + String codecName, + boolean supportCompressionLevel, int defaultCompressionLevel) { this.codecName = codecName; + this.supportCompressionLevel = supportCompressionLevel; + this.defaultCompressionLevel = defaultCompressionLevel; } public String getCodecName() { return this.codecName; } + public boolean getSupportCompressionLevel() { + return this.supportCompressionLevel; + } + + public int getDefaultCompressionLevel() { + return 
this.defaultCompressionLevel; + } + private static final Map codecNameMap = Arrays.stream(AvroCompressionCodec.values()).collect( Collectors.toMap(codec -> codec.name(), codec -> codec.name().toLowerCase(Locale.ROOT))); diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 3910cf540628e..d9c88e14d039e 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters._ import org.apache.avro.Schema import org.apache.avro.file.{DataFileReader, FileReader} import org.apache.avro.generic.{GenericDatumReader, GenericRecord} -import org.apache.avro.mapred.{AvroOutputFormat, FsInput} +import org.apache.avro.mapred.FsInput import org.apache.avro.mapreduce.AvroJob import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus @@ -110,10 +110,12 @@ private[sql] object AvroUtils extends Logging { case compressed => job.getConfiguration.setBoolean("mapred.output.compress", true) job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, compressed.getCodecName) - if (compressed == DEFLATE) { - val deflateLevel = sqlConf.avroDeflateLevel - logInfo(s"Compressing Avro output using the $codecName codec at level $deflateLevel") - job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel) + if (compressed.getSupportCompressionLevel) { + val level = sqlConf.getConfString(s"spark.sql.avro.$codecName.level", + compressed.getDefaultCompressionLevel.toString) + logInfo(s"Compressing Avro output using the $codecName codec at level $level") + val s = if (compressed == ZSTANDARD) "zstd" else codecName + job.getConfiguration.setInt(s"avro.mapred.$s.level", level.toInt) } else { logInfo(s"Compressing Avro output using the $codecName codec") } diff --git 
a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala index 933b3f989ef7c..256b608feaa1f 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.avro import java.util.Locale import org.apache.spark.SparkIllegalArgumentException +import org.apache.spark.sql.Row import org.apache.spark.sql.execution.datasources.FileSourceCodecSuite import org.apache.spark.sql.internal.SQLConf @@ -58,4 +59,20 @@ class AvroCodecSuite extends FileSourceCodecSuite { parameters = Map("codecName" -> "unsupported") ) } + + test("SPARK-46759: compression level support for zstandard codec") { + Seq("9", "1").foreach { level => + withSQLConf( + (SQLConf.AVRO_COMPRESSION_CODEC.key -> "zstandard"), + (SQLConf.AVRO_ZSTANDARD_LEVEL.key -> level)) { + withTable("avro_t") { + sql( + s"""CREATE TABLE avro_t + |USING $format + |AS SELECT 1 as id""".stripMargin) + checkAnswer(spark.table("avro_t"), Seq(Row(1))) + } + } + } + } } diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index 898afe9de87ff..594322f5dc835 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -362,6 +362,24 @@ Configuration of Avro can be done via `spark.conf.set` or by running `SET key=va 2.4.0 + + spark.sql.avro.xz.level + 6 + + Compression level for the xz codec used in writing of AVRO files. Valid value must be in + the range of from 1 to 9 inclusive. The default value is 6 in the current implementation. + + 4.0.0 + + + spark.sql.avro.zstandard.level + 3 + + Compression level for the zstandard codec used in writing of AVRO files. + The default value is 3 in the current implementation. 
+ + 4.0.0 + spark.sql.avro.datetimeRebaseModeInRead EXCEPTION diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index eb5233bfb1231..61c7b2457b11e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3625,7 +3625,23 @@ object SQLConf { .version("2.4.0") .intConf .checkValues((1 to 9).toSet + Deflater.DEFAULT_COMPRESSION) - .createWithDefault(Deflater.DEFAULT_COMPRESSION) + .createOptional + + val AVRO_XZ_LEVEL = buildConf("spark.sql.avro.xz.level") + .doc("Compression level for the xz codec used in writing of AVRO files. " + + "Valid value must be in the range of from 1 to 9 inclusive. " + + "The default value is 6.") + .version("4.0.0") + .intConf + .checkValue(v => v > 0 && v <= 9, "The value must be in the range of from 1 to 9 inclusive.") + .createOptional + + val AVRO_ZSTANDARD_LEVEL = buildConf("spark.sql.avro.zstandard.level") + .doc("Compression level for the zstandard codec used in writing of AVRO files. " + + "The default value is 3.") + .version("4.0.0") + .intConf + .createOptional val LEGACY_SIZE_OF_NULL = buildConf("spark.sql.legacy.sizeOfNull") .internal() @@ -5421,8 +5437,6 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def avroCompressionCodec: String = getConf(SQLConf.AVRO_COMPRESSION_CODEC) - def avroDeflateLevel: Int = getConf(SQLConf.AVRO_DEFLATE_LEVEL) - def replaceDatabricksSparkAvroEnabled: Boolean = getConf(SQLConf.LEGACY_REPLACE_DATABRICKS_SPARK_AVRO_ENABLED)