diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
index 3eba013c14353..15c76ec358edc 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.avro
 
 import org.apache.avro.Schema
+import org.apache.avro.mapreduce.AvroJob
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 
 import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
@@ -38,7 +39,14 @@ private[sql] class AvroOutputWriterFactory(
 
   private lazy val avroSchema = new Schema.Parser().parse(avroSchemaAsJsonString)
 
-  override def getFileExtension(context: TaskAttemptContext): String = ".avro"
+  override def getFileExtension(context: TaskAttemptContext): String = {
+    val codec = context.getConfiguration.get(AvroJob.CONF_OUTPUT_CODEC)
+    if (codec == null || codec.equalsIgnoreCase("null")) {
+      ".avro"
+    } else {
+      s".$codec.avro"
+    }
+  }
 
   override def newInstance(
       path: String,
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
index 4e4942e1b2e26..ec3753b84a559 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCodecSuite.scala
@@ -26,4 +26,20 @@ class AvroCodecSuite extends FileSourceCodecSuite {
   override val codecConfigName: String = SQLConf.AVRO_COMPRESSION_CODEC.key
   override protected def availableCodecs =
     AvroCompressionCodec.values().map(_.lowerCaseName()).iterator.to(Seq)
+
+  availableCodecs.foreach { codec =>
+    test(s"SPARK-46746: attach codec name to avro files - codec $codec") {
+      withTable("avro_t") {
+        sql(
+          s"""CREATE TABLE avro_t
+             | USING $format OPTIONS('compression'='$codec')
+             | AS SELECT 1 as id
+             | """.stripMargin)
+        spark.table("avro_t")
+          .inputFiles.foreach { file =>
+            assert(file.endsWith(s"$codec.avro".stripPrefix("uncompressed")))
+          }
+      }
+    }
+  }
 }