diff --git a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 95835f0d4ca4..d89f96305964 100644
--- a/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/external/avro/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -1 +1 @@
-org.apache.spark.sql.avro.AvroFileFormat
+org.apache.spark.sql.v2.avro.AvroDataSourceV2
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
index be8223ccc9df..123669ba1376 100755
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala
@@ -23,25 +23,23 @@ import java.net.URI
 import scala.util.control.NonFatal
 
 import org.apache.avro.Schema
-import org.apache.avro.file.DataFileConstants._
 import org.apache.avro.file.DataFileReader
 import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
-import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
-import org.apache.avro.mapreduce.AvroJob
+import org.apache.avro.mapred.FsInput
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.Job
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
 import org.apache.spark.sql.types._
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.SerializableConfiguration
 
-private[avro] class AvroFileFormat extends FileFormat
+private[sql] class AvroFileFormat extends FileFormat
   with DataSourceRegister with Logging with Serializable {
 
   override def equals(other: Any): Boolean = other match {
@@ -56,74 +54,7 @@ private[avro] class AvroFileFormat extends FileFormat
       spark: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    val conf = spark.sessionState.newHadoopConf()
-    if (options.contains("ignoreExtension")) {
-      logWarning(s"Option ${AvroOptions.ignoreExtensionKey} is deprecated. Please use the " +
-        "general data source option pathGlobFilter for filtering file names.")
-    }
-    val parsedOptions = new AvroOptions(options, conf)
-
-    // User can specify an optional avro json schema.
-    val avroSchema = parsedOptions.schema
-      .map(new Schema.Parser().parse)
-      .getOrElse {
-        inferAvroSchemaFromFiles(files, conf, parsedOptions.ignoreExtension,
-          spark.sessionState.conf.ignoreCorruptFiles)
-      }
-
-    SchemaConverters.toSqlType(avroSchema).dataType match {
-      case t: StructType => Some(t)
-      case _ => throw new RuntimeException(
-        s"""Avro schema cannot be converted to a Spark SQL StructType:
-           |
-           |${avroSchema.toString(true)}
-           |""".stripMargin)
-    }
-  }
-
-  private def inferAvroSchemaFromFiles(
-      files: Seq[FileStatus],
-      conf: Configuration,
-      ignoreExtension: Boolean,
-      ignoreCorruptFiles: Boolean): Schema = {
-    // Schema evolution is not supported yet. Here we only pick first random readable sample file to
-    // figure out the schema of the whole dataset.
-    val avroReader = files.iterator.map { f =>
-      val path = f.getPath
-      if (!ignoreExtension && !path.getName.endsWith(".avro")) {
-        None
-      } else {
-        Utils.tryWithResource {
-          new FsInput(path, conf)
-        } { in =>
-          try {
-            Some(DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()))
-          } catch {
-            case e: IOException =>
-              if (ignoreCorruptFiles) {
-                logWarning(s"Skipped the footer in the corrupted file: $path", e)
-                None
-              } else {
-                throw new SparkException(s"Could not read file: $path", e)
-              }
-          }
-        }
-      }
-    }.collectFirst {
-      case Some(reader) => reader
-    }
-
-    avroReader match {
-      case Some(reader) =>
-        try {
-          reader.getSchema
-        } finally {
-          reader.close()
-        }
-      case None =>
-        throw new FileNotFoundException(
-          "No Avro files found. If files don't have .avro extension, set ignoreExtension to true")
-    }
+    AvroUtils.inferSchema(spark, options, files)
   }
 
   override def shortName(): String = "avro"
@@ -140,32 +71,7 @@ private[avro] class AvroFileFormat extends FileFormat
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf())
-    val outputAvroSchema: Schema = parsedOptions.schema
-      .map(new Schema.Parser().parse)
-      .getOrElse(SchemaConverters.toAvroType(dataSchema, nullable = false,
-        parsedOptions.recordName, parsedOptions.recordNamespace))
-
-    AvroJob.setOutputKeySchema(job, outputAvroSchema)
-
-    if (parsedOptions.compression == "uncompressed") {
-      job.getConfiguration.setBoolean("mapred.output.compress", false)
-    } else {
-      job.getConfiguration.setBoolean("mapred.output.compress", true)
-      logInfo(s"Compressing Avro output using the ${parsedOptions.compression} codec")
-      val codec = parsedOptions.compression match {
-        case DEFLATE_CODEC =>
-          val deflateLevel = spark.sessionState.conf.avroDeflateLevel
-          logInfo(s"Avro compression level $deflateLevel will be used for $DEFLATE_CODEC codec.")
-          job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel)
-          DEFLATE_CODEC
-        case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC) => codec
-        case unknown => throw new IllegalArgumentException(s"Invalid compression codec: $unknown")
-      }
-      job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec)
-    }
-
-    new AvroOutputWriterFactory(dataSchema, outputAvroSchema.toString)
+    AvroUtils.prepareWrite(spark.sessionState.conf, job, options, dataSchema)
   }
 
   override def buildReader(
@@ -250,22 +156,7 @@ private[avro] class AvroFileFormat extends FileFormat
     }
   }
 
-  override def supportDataType(dataType: DataType): Boolean = dataType match {
-    case _: AtomicType => true
-
-    case st: StructType => st.forall { f => supportDataType(f.dataType) }
-
-    case ArrayType(elementType, _) => supportDataType(elementType)
-
-    case MapType(keyType, valueType, _) =>
-      supportDataType(keyType) && supportDataType(valueType)
-
-    case udt: UserDefinedType[_] => supportDataType(udt.sqlType)
-
-    case _: NullType => true
-
-    case _ => false
-  }
+  override def supportDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
 }
 
 private[avro] object AvroFileFormat {
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
index 116020ed5c43..0074044544c0 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriterFactory.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.StructType
  * @param catalystSchema Catalyst schema of input data.
  * @param avroSchemaAsJsonString Avro schema of output result, in JSON string format.
  */
-private[avro] class AvroOutputWriterFactory(
+private[sql] class AvroOutputWriterFactory(
     catalystSchema: StructType,
     avroSchemaAsJsonString: String) extends OutputWriterFactory {
 
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
new file mode 100644
index 000000000000..b978b7974b92
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.avro
+
+import java.io.{FileNotFoundException, IOException}
+
+import org.apache.avro.Schema
+import org.apache.avro.file.DataFileConstants.{BZIP2_CODEC, DEFLATE_CODEC, SNAPPY_CODEC, XZ_CODEC}
+import org.apache.avro.file.DataFileReader
+import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
+import org.apache.avro.mapred.{AvroOutputFormat, FsInput}
+import org.apache.avro.mapreduce.AvroJob
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.mapreduce.Job
+
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.OutputWriterFactory
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+object AvroUtils extends Logging {
+  def inferSchema(
+      spark: SparkSession,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Option[StructType] = {
+    val conf = spark.sessionState.newHadoopConf()
+    if (options.contains("ignoreExtension")) {
+      logWarning(s"Option ${AvroOptions.ignoreExtensionKey} is deprecated. Please use the " +
+        "general data source option pathGlobFilter for filtering file names.")
+    }
+    val parsedOptions = new AvroOptions(options, conf)
+
+    // Users can specify an optional Avro JSON schema.
+    val avroSchema = parsedOptions.schema
+      .map(new Schema.Parser().parse)
+      .getOrElse {
+        inferAvroSchemaFromFiles(files, conf, parsedOptions.ignoreExtension,
+          spark.sessionState.conf.ignoreCorruptFiles)
+      }
+
+    SchemaConverters.toSqlType(avroSchema).dataType match {
+      case t: StructType => Some(t)
+      case _ => throw new RuntimeException(
+        s"""Avro schema cannot be converted to a Spark SQL StructType:
+           |
+           |${avroSchema.toString(true)}
+           |""".stripMargin)
+    }
+  }
+
+  def supportsDataType(dataType: DataType): Boolean = dataType match {
+    case _: AtomicType => true
+
+    case st: StructType => st.forall { f => supportsDataType(f.dataType) }
+
+    case ArrayType(elementType, _) => supportsDataType(elementType)
+
+    case MapType(keyType, valueType, _) =>
+      supportsDataType(keyType) && supportsDataType(valueType)
+
+    case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+
+    case _: NullType => true
+
+    case _ => false
+  }
+
+  def prepareWrite(
+      sqlConf: SQLConf,
+      job: Job,
+      options: Map[String, String],
+      dataSchema: StructType): OutputWriterFactory = {
+    val parsedOptions = new AvroOptions(options, job.getConfiguration)
+    val outputAvroSchema: Schema = parsedOptions.schema
+      .map(new Schema.Parser().parse)
+      .getOrElse(SchemaConverters.toAvroType(dataSchema, nullable = false,
+        parsedOptions.recordName, parsedOptions.recordNamespace))
+
+    AvroJob.setOutputKeySchema(job, outputAvroSchema)
+
+    if (parsedOptions.compression == "uncompressed") {
+      job.getConfiguration.setBoolean("mapred.output.compress", false)
+    } else {
+      job.getConfiguration.setBoolean("mapred.output.compress", true)
+      logInfo(s"Compressing Avro output using the ${parsedOptions.compression} codec")
+      val codec = parsedOptions.compression match {
+        case DEFLATE_CODEC =>
+          val deflateLevel = sqlConf.avroDeflateLevel
+          logInfo(s"Avro compression level $deflateLevel will be used for $DEFLATE_CODEC codec.")
+          job.getConfiguration.setInt(AvroOutputFormat.DEFLATE_LEVEL_KEY, deflateLevel)
+          DEFLATE_CODEC
+        case codec @ (SNAPPY_CODEC | BZIP2_CODEC | XZ_CODEC) => codec
+        case unknown => throw new IllegalArgumentException(s"Invalid compression codec: $unknown")
+      }
+      job.getConfiguration.set(AvroJob.CONF_OUTPUT_CODEC, codec)
+    }
+
+    new AvroOutputWriterFactory(dataSchema, outputAvroSchema.toString)
+  }
+
+  private def inferAvroSchemaFromFiles(
+      files: Seq[FileStatus],
+      conf: Configuration,
+      ignoreExtension: Boolean,
+      ignoreCorruptFiles: Boolean): Schema = {
+    // Schema evolution is not supported yet. Here we pick the first readable sample file to
+    // figure out the schema of the whole dataset.
+    val avroReader = files.iterator.map { f =>
+      val path = f.getPath
+      if (!ignoreExtension && !path.getName.endsWith(".avro")) {
+        None
+      } else {
+        Utils.tryWithResource {
+          new FsInput(path, conf)
+        } { in =>
+          try {
+            Some(DataFileReader.openReader(in, new GenericDatumReader[GenericRecord]()))
+          } catch {
+            case e: IOException =>
+              if (ignoreCorruptFiles) {
+                logWarning(s"Skipped the footer in the corrupted file: $path", e)
+                None
+              } else {
+                throw new SparkException(s"Could not read file: $path", e)
+              }
+          }
+        }
+      }
+    }.collectFirst {
+      case Some(reader) => reader
+    }
+
+    avroReader match {
+      case Some(reader) =>
+        try {
+          reader.getSchema
+        } finally {
+          reader.close()
+        }
+      case None =>
+        throw new FileNotFoundException(
+          "No Avro files found. If files don't have .avro extension, set ignoreExtension to true")
+    }
+  }
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
new file mode 100644
index 000000000000..3171f1e08b4f
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroDataSourceV2.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.v2.avro
+
+import org.apache.spark.sql.avro.AvroFileFormat
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
+import org.apache.spark.sql.sources.v2.Table
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class AvroDataSourceV2 extends FileDataSourceV2 {
+
+  override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[AvroFileFormat]
+
+  override def shortName(): String = "avro"
+
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    val paths = getPaths(options)
+    val tableName = getTableName(paths)
+    AvroTable(tableName, sparkSession, options, paths, None, fallbackFileFormat)
+  }
+
+  override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
+    val paths = getPaths(options)
+    val tableName = getTableName(paths)
+    AvroTable(tableName, sparkSession, options, paths, Some(schema), fallbackFileFormat)
+  }
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
new file mode 100644
index 000000000000..243af7da4700
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.v2.avro
+
+import java.net.URI
+
+import scala.util.control.NonFatal
+
+import org.apache.avro.Schema
+import org.apache.avro.file.DataFileReader
+import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
+import org.apache.avro.mapred.FsInput
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory, PartitionReaderWithPartitionValues}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.reader.PartitionReader
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * A factory used to create Avro readers.
+ *
+ * @param sqlConf SQL configuration.
+ * @param broadcastedConf Broadcasted serializable Hadoop configuration.
+ * @param dataSchema Schema of Avro files.
+ * @param readDataSchema Required data schema of Avro files.
+ * @param partitionSchema Schema of partitions.
+ * @param options Options for parsing Avro files.
+ */
+case class AvroPartitionReaderFactory(
+    sqlConf: SQLConf,
+    broadcastedConf: Broadcast[SerializableConfiguration],
+    dataSchema: StructType,
+    readDataSchema: StructType,
+    partitionSchema: StructType,
+    options: Map[String, String]) extends FilePartitionReaderFactory with Logging {
+
+  override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = {
+    val conf = broadcastedConf.value.value
+    val parsedOptions = new AvroOptions(options, conf)
+    val userProvidedSchema = parsedOptions.schema.map(new Schema.Parser().parse)
+
+    if (parsedOptions.ignoreExtension || partitionedFile.filePath.endsWith(".avro")) {
+      val reader = {
+        val in = new FsInput(new Path(new URI(partitionedFile.filePath)), conf)
+        try {
+          val datumReader = userProvidedSchema match {
+            case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema)
+            case _ => new GenericDatumReader[GenericRecord]()
+          }
+          DataFileReader.openReader(in, datumReader)
+        } catch {
+          case NonFatal(e) =>
+            logError("Exception while opening DataFileReader", e)
+            in.close()
+            throw e
+        }
+      }
+
+      // Ensure that the reader is closed even if the task fails or doesn't consume the entire
+      // iterator of records.
+      Option(TaskContext.get()).foreach { taskContext =>
+        taskContext.addTaskCompletionListener[Unit] { _ =>
+          reader.close()
+        }
+      }
+
+      reader.sync(partitionedFile.start)
+      val stop = partitionedFile.start + partitionedFile.length
+
+      val deserializer =
+        new AvroDeserializer(userProvidedSchema.getOrElse(reader.getSchema), readDataSchema)
+
+      val fileReader = new PartitionReader[InternalRow] {
+        private[this] var completed = false
+
+        override def next(): Boolean = {
+          if (completed) {
+            false
+          } else {
+            val r = reader.hasNext && !reader.pastSync(stop)
+            if (!r) {
+              reader.close()
+              completed = true
+            }
+            r
+          }
+        }
+
+        override def get(): InternalRow = {
+          if (!next) {
+            throw new NoSuchElementException("next on empty iterator")
+          }
+          val record = reader.next()
+          deserializer.deserialize(record).asInstanceOf[InternalRow]
+        }
+
+        override def close(): Unit = reader.close()
+      }
+      new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
+        partitionSchema, partitionedFile.partitionValues)
+    } else {
+      new EmptyPartitionReader[InternalRow]
+    }
+  }
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
new file mode 100644
index 000000000000..6ec351080a11
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.v2.avro
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.v2.FileScan
+import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
+
+case class AvroScan(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    dataSchema: StructType,
+    readDataSchema: StructType,
+    readPartitionSchema: StructType,
+    options: CaseInsensitiveStringMap)
+  extends FileScan(sparkSession, fileIndex, readDataSchema, readPartitionSchema) {
+  override def isSplitable(path: Path): Boolean = true
+
+  override def createReaderFactory(): PartitionReaderFactory = {
+    val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
+    // Hadoop Configurations are case sensitive.
+    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
+    val broadcastedConf = sparkSession.sparkContext.broadcast(
+      new SerializableConfiguration(hadoopConf))
+    // The partition values are already truncated in `FileScan.partitions`.
+    // We should use `readPartitionSchema` as the partition schema here.
+    AvroPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
+      dataSchema, readDataSchema, readPartitionSchema, caseSensitiveMap)
+  }
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
new file mode 100644
index 000000000000..815da2bd92d4
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.v2.avro
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
+import org.apache.spark.sql.sources.v2.reader.Scan
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class AvroScanBuilder(
+    sparkSession: SparkSession,
+    fileIndex: PartitioningAwareFileIndex,
+    schema: StructType,
+    dataSchema: StructType,
+    options: CaseInsensitiveStringMap)
+  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
+  override def build(): Scan = {
+    AvroScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options)
+  }
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
new file mode 100644
index 000000000000..a781624aa61a
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.v2.avro
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.FileStatus
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.avro.AvroUtils
+import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.datasources.v2.FileTable
+import org.apache.spark.sql.sources.v2.writer.WriteBuilder
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+case class AvroTable(
+    name: String,
+    sparkSession: SparkSession,
+    options: CaseInsensitiveStringMap,
+    paths: Seq[String],
+    userSpecifiedSchema: Option[StructType],
+    fallbackFileFormat: Class[_ <: FileFormat])
+  extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
+  override def newScanBuilder(options: CaseInsensitiveStringMap): AvroScanBuilder =
+    new AvroScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
+
+  override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
+    AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files)
+
+  override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
+    new AvroWriteBuilder(options, paths, formatName, supportsDataType)
+
+  override def supportsDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType)
+
+  override def formatName: String = "AVRO"
+}
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWriteBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWriteBuilder.scala
new file mode 100644
index 000000000000..c2ddc4b19127
--- /dev/null
+++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWriteBuilder.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.v2.avro
+
+import org.apache.hadoop.mapreduce.Job
+
+import org.apache.spark.sql.avro.AvroUtils
+import org.apache.spark.sql.execution.datasources.OutputWriterFactory
+import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class AvroWriteBuilder(
+    options: CaseInsensitiveStringMap,
+    paths: Seq[String],
+    formatName: String,
+    supportsDataType: DataType => Boolean)
+  extends FileWriteBuilder(options, paths, formatName, supportsDataType) {
+  override def prepareWrite(
+      sqlConf: SQLConf,
+      job: Job,
+      options: Map[String, String],
+      dataSchema: StructType): OutputWriterFactory = {
+    AvroUtils.prepareWrite(sqlConf, job, options, dataSchema)
+  }
+}
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
index 79ba2871c226..96382764b053 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
@@ -24,14 +24,14 @@ import org.apache.avro.Conversions.DecimalConversion
 import org.apache.avro.file.DataFileWriter
 import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord}
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
 
-class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+abstract class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   import testImplicits._
 
   val dateSchema = s"""
@@ -349,3 +349,19 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU
     }
   }
 }
+
+class AvroV1LogicalTypeSuite extends AvroLogicalTypeSuite {
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "avro")
+      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "avro")
+}
+
+class AvroV2LogicalTypeSuite extends AvroLogicalTypeSuite {
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "")
+      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "")
+}
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 49aa21884f8b..40bf3b1530fb 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -33,7 +33,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql._
 import org.apache.spark.sql.TestingUDT.{IntervalData, NullData, NullUDT}
 import org.apache.spark.sql.execution.datasources.DataSource
@@ -42,7 +42,7 @@ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+abstract class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   import testImplicits._
 
   val episodesAvro = testFile("episodes.avro")
@@ -81,7 +81,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
   test("resolve avro data source") {
     val databricksAvro = "com.databricks.spark.avro"
     // By default the backward compatibility for com.databricks.spark.avro is enabled.
-    Seq("avro", "org.apache.spark.sql.avro.AvroFileFormat", databricksAvro).foreach { provider =>
+    Seq("org.apache.spark.sql.avro.AvroFileFormat", databricksAvro).foreach { provider =>
       assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) ===
         classOf[org.apache.spark.sql.avro.AvroFileFormat])
     }
@@ -1000,7 +1000,8 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     var msg = intercept[AnalysisException] {
       sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir)
     }.getMessage
-    assert(msg.contains("Cannot save interval data type into external storage."))
+    assert(msg.contains("Cannot save interval data type into external storage.") ||
+      msg.contains("AVRO data source does not support calendarinterval data type."))
 
     msg = intercept[AnalysisException] {
       spark.udf.register("testType", () => new IntervalData())
@@ -1492,3 +1493,19 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       """.stripMargin)
   }
 }
+
+class AvroV1Suite extends AvroSuite {
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "avro")
+      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "avro")
+}
+
+class AvroV2Suite extends AvroSuite {
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "")
+      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "")
+}
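The V1/V2 suite pairs above rely on the fallback lists in SQLConf: with "avro" in the lists, reads and writes go through the old AvroFileFormat path; with the lists empty, they go through AvroDataSourceV2. A minimal sketch of the same toggle outside the test suites, assuming a local session and hypothetical paths (not part of this patch):

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.internal.SQLConf

  // Route Avro through the new V2 path by leaving the V1 fallback lists empty
  // (set them to "avro" instead to force the old V1 path).
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("AvroV2SmokeTest")
    .config(SQLConf.USE_V1_SOURCE_READER_LIST.key, "")
    .config(SQLConf.USE_V1_SOURCE_WRITER_LIST.key, "")
    .getOrCreate()

  // "avro" now resolves to AvroDataSourceV2 via the updated DataSourceRegister entry.
  val df = spark.read.format("avro").load("/tmp/events.avro")  // hypothetical input path
  df.write.format("avro").mode("overwrite").save("/tmp/events_copy")  // hypothetical output path

Either way the write ultimately goes through AvroUtils.prepareWrite, so options such as compression behave identically on both paths.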