File: AvroOptions.scala

@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.FileSourceOptions
+import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions}
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode}
 import org.apache.spark.sql.internal.SQLConf
 
@@ -37,6 +37,8 @@ private[sql] class AvroOptions(
     @transient val conf: Configuration)
   extends FileSourceOptions(parameters) with Logging {
 
+  import AvroOptions._
+
   def this(parameters: Map[String, String], conf: Configuration) = {
     this(CaseInsensitiveMap(parameters), conf)
   }
@@ -54,8 +56,8 @@ private[sql] class AvroOptions(
    * instead of "string" type in the default converted schema.
    */
   val schema: Option[Schema] = {
-    parameters.get("avroSchema").map(new Schema.Parser().setValidateDefaults(false).parse).orElse({
-      val avroUrlSchema = parameters.get("avroSchemaUrl").map(url => {
+    parameters.get(AVRO_SCHEMA).map(new Schema.Parser().setValidateDefaults(false).parse).orElse({
+      val avroUrlSchema = parameters.get(AVRO_SCHEMA_URL).map(url => {
         log.debug("loading avro schema from url: " + url)
         val fs = FileSystem.get(new URI(url), conf)
         val in = fs.open(new Path(url))
@@ -75,20 +77,20 @@
    * whose field names do not match. Defaults to false.
    */
   val positionalFieldMatching: Boolean =
-    parameters.get("positionalFieldMatching").exists(_.toBoolean)
+    parameters.get(POSITIONAL_FIELD_MATCHING).exists(_.toBoolean)
 
   /**
    * Top level record name in write result, which is required in Avro spec.
    * See https://avro.apache.org/docs/1.11.1/specification/#schema-record .
    * Default value is "topLevelRecord"
    */
-  val recordName: String = parameters.getOrElse("recordName", "topLevelRecord")
+  val recordName: String = parameters.getOrElse(RECORD_NAME, "topLevelRecord")
 
   /**
    * Record namespace in write result. Default value is "".
    * See Avro spec for details: https://avro.apache.org/docs/1.11.1/specification/#schema-record .
    */
-  val recordNamespace: String = parameters.getOrElse("recordNamespace", "")
+  val recordNamespace: String = parameters.getOrElse(RECORD_NAMESPACE, "")
 
   /**
    * The `ignoreExtension` option controls ignoring of files without `.avro` extensions in read.
@@ -104,7 +106,7 @@ private[sql] class AvroOptions(
       ignoreFilesWithoutExtensionByDefault)
 
     parameters
-      .get(AvroOptions.ignoreExtensionKey)
+      .get(IGNORE_EXTENSION)
      .map(_.toBoolean)
      .getOrElse(!ignoreFilesWithoutExtension)
   }
@@ -116,21 +118,21 @@ private[sql] class AvroOptions(
    * taken into account. If the former one is not set too, the `snappy` codec is used by default.
    */
   val compression: String = {
-    parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec)
+    parameters.get(COMPRESSION).getOrElse(SQLConf.get.avroCompressionCodec)
   }
 
   val parseMode: ParseMode =
-    parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode)
+    parameters.get(MODE).map(ParseMode.fromString).getOrElse(FailFastMode)
 
   /**
    * The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads.
    */
   val datetimeRebaseModeInRead: String = parameters
-    .get(AvroOptions.DATETIME_REBASE_MODE)
+    .get(DATETIME_REBASE_MODE)
     .getOrElse(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ))
 }
 
-private[sql] object AvroOptions {
+private[sql] object AvroOptions extends DataSourceOptions {
   def apply(parameters: Map[String, String]): AvroOptions = {
     val hadoopConf = SparkSession
       .getActiveSession
@@ -139,11 +141,17 @@ private[sql] object AvroOptions {
     new AvroOptions(CaseInsensitiveMap(parameters), hadoopConf)
   }
 
-  val ignoreExtensionKey = "ignoreExtension"
-
+  val IGNORE_EXTENSION = newOption("ignoreExtension")
+  val MODE = newOption("mode")
+  val RECORD_NAME = newOption("recordName")
+  val COMPRESSION = newOption("compression")
+  val AVRO_SCHEMA = newOption("avroSchema")
+  val AVRO_SCHEMA_URL = newOption("avroSchemaUrl")
+  val RECORD_NAMESPACE = newOption("recordNamespace")
+  val POSITIONAL_FIELD_MATCHING = newOption("positionalFieldMatching")
   // The option controls rebasing of the DATE and TIMESTAMP values between
   // Julian and Proleptic Gregorian calendars. It impacts on the behaviour of the Avro
   // datasource similarly to the SQL config `spark.sql.avro.datetimeRebaseModeInRead`,
   // and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`.
-  val DATETIME_REBASE_MODE = "datetimeRebaseMode"
+  val DATETIME_REBASE_MODE = newOption("datetimeRebaseMode")
 }
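
For context, a minimal sketch of how a couple of the options registered above are exercised from the DataFrame reader. The constants are private[sql], so end-user code passes the raw string keys; the application name and input path below are invented for illustration:

    import org.apache.spark.sql.SparkSession

    object AvroOptionsUsage {
      def main(args: Array[String]): Unit = {
        // Assumes spark-avro is on the classpath.
        val spark = SparkSession.builder().appName("avro-options-usage").getOrCreate()
        val df = spark.read
          .format("avro")
          // Match fields of the user-provided schema to the Avro schema by
          // position rather than by name (POSITIONAL_FIELD_MATCHING above).
          .option("positionalFieldMatching", "true")
          // Rebase ancient DATE/TIMESTAMP values to the Proleptic Gregorian
          // calendar (DATETIME_REBASE_MODE above).
          .option("datetimeRebaseMode", "CORRECTED")
          .load("/tmp/events") // hypothetical input path
        df.show()
      }
    }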
File: AvroUtils.scala

@@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.Job
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.avro.AvroOptions.ignoreExtensionKey
+import org.apache.spark.sql.avro.AvroOptions.IGNORE_EXTENSION
 import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.datasources.OutputWriterFactory
@@ -50,8 +50,8 @@ private[sql] object AvroUtils extends Logging {
     val conf = spark.sessionState.newHadoopConfWithOptions(options)
     val parsedOptions = new AvroOptions(options, conf)
 
-    if (parsedOptions.parameters.contains(ignoreExtensionKey)) {
-      logWarning(s"Option $ignoreExtensionKey is deprecated. Please use the " +
+    if (parsedOptions.parameters.contains(IGNORE_EXTENSION)) {
+      logWarning(s"Option $IGNORE_EXTENSION is deprecated. Please use the " +
         "general data source option pathGlobFilter for filtering file names.")
     }
     // User can specify an optional avro json schema.
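
Since the warning steers users toward pathGlobFilter, here is a minimal sketch of the replacement for ignoreExtension (assumes an active SparkSession named spark; the directory path is invented):

    // Instead of the deprecated .option("ignoreExtension", "false"), restrict the
    // read to files ending in .avro via the general data source option pathGlobFilter.
    val avroOnly = spark.read
      .format("avro")
      .option("pathGlobFilter", "*.avro")
      .load("/data/mixed") // hypothetical directory containing mixed file types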
File: AvroSuite.scala

@@ -1804,13 +1804,13 @@ abstract class AvroSuite
         spark
           .read
           .format("avro")
-          .option(AvroOptions.ignoreExtensionKey, false)
+          .option(AvroOptions.IGNORE_EXTENSION, false)
           .load(dir.getCanonicalPath)
           .count()
       }
       val deprecatedEvents = logAppender.loggingEvents
         .filter(_.getMessage.getFormattedMessage.contains(
-          s"Option ${AvroOptions.ignoreExtensionKey} is deprecated"))
+          s"Option ${AvroOptions.IGNORE_EXTENSION} is deprecated"))
       assert(deprecatedEvents.size === 1)
     }
   }
@@ -2272,6 +2272,20 @@ abstract class AvroSuite
       checkAnswer(df2, df.collect().toSeq)
     }
   }
+
+  test("SPARK-40667: validate Avro Options") {
+    assert(AvroOptions.getAllOptions.size == 9)
+    // Please add validation on any new Avro options here
+    assert(AvroOptions.isValidOption("ignoreExtension"))
+    assert(AvroOptions.isValidOption("mode"))
+    assert(AvroOptions.isValidOption("recordName"))
+    assert(AvroOptions.isValidOption("compression"))
+    assert(AvroOptions.isValidOption("avroSchema"))
+    assert(AvroOptions.isValidOption("avroSchemaUrl"))
+    assert(AvroOptions.isValidOption("recordNamespace"))
+    assert(AvroOptions.isValidOption("positionalFieldMatching"))
+    assert(AvroOptions.isValidOption("datetimeRebaseMode"))
+  }
 }
 
 class AvroV1Suite extends AvroSuite {
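
If a tenth option is ever added, the registration in AvroOptions and this test have to move in lockstep; a hedged sketch with an invented option name:

    // Hypothetical: register a new option inside object AvroOptions...
    val SOME_NEW_OPTION = newOption("someNewOption") // invented name, for illustration

    // ...and extend the validation test to match:
    assert(AvroOptions.getAllOptions.size == 10)
    assert(AvroOptions.isValidOption("someNewOption"))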
File: DataSourceOptions.scala (new file)

@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+/**
+ * Interface that defines the following methods for a data source:
+ *   - register a new option name
+ *   - retrieve all registered option names
+ *   - validate a given option name
+ *   - get the alternative option name, if any
+ */
+trait DataSourceOptions {
+  // Option -> Alternative Option if any
+  private val validOptions = collection.mutable.Map[String, Option[String]]()
+
+  /**
+   * Register a new Option.
+   */
+  protected def newOption(name: String): String = {
+    validOptions += (name -> None)
+    name
+  }
+
+  /**
+   * Register a new Option with an alternative name.
+   * @param name Option name
+   * @param alternative Alternative option name
+   */
+  protected def newOption(name: String, alternative: String): Unit = {
+    // Register both of the options

  [Review thread on the alternative-option registration]

  @cloud-fan (Contributor), Oct 11, 2022:

    How can we know which one is the primary one? For example,

        val charset = parameters.getOrElse(ENCODING,
          parameters.getOrElse(CHARSET, StandardCharsets.UTF_8.name()))

    ENCODING is the primary one, as it will be respected if both are set.

  @xiaonanyang-db (Contributor, Author), Oct 11, 2022:

    I think we don't really care about which one is primary here. The reason we
    want to track alternative options is that callers may want to raise an error
    or log a warning if both of the alternative options are provided. Which one
    is respected can be decided by the caller.


+    validOptions += (name -> Some(alternative))
+    validOptions += (alternative -> Some(name))
+  }
+
+  /**
+   * @return All data source options and their alternatives, if any
+   */
+  def getAllOptions: scala.collection.Set[String] = validOptions.keySet
+
+  /**
+   * @param name Option name to be validated
+   * @return Whether the given option name is valid
+   */
+  def isValidOption(name: String): Boolean = validOptions.contains(name)
+
+  /**
+   * @param name Option name
+   * @return The alternative option name, if any
+   */
+  def getAlternativeOption(name: String): Option[String] = validOptions.get(name).flatten
+}
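
Tying the review thread back to the API: a hedged sketch of how a caller could register an alternative pair and resolve precedence itself. CsvLikeOptions and its option names are invented here, mirroring the ENCODING/CHARSET example from the thread:

    import java.nio.charset.StandardCharsets

    import org.apache.spark.sql.catalyst.DataSourceOptions

    // Hypothetical options object; "encoding"/"charset" mirror the review example.
    object CsvLikeOptions extends DataSourceOptions {
      val ENCODING = "encoding"
      val CHARSET = "charset"
      // Registers both names, each recorded as the other's alternative.
      newOption(ENCODING, CHARSET)
    }

    object AlternativeOptionExample {
      // The caller decides precedence (ENCODING wins here) and can consult
      // getAlternativeOption to warn when both spellings are supplied.
      def resolveCharset(parameters: Map[String, String]): String = {
        for (alt <- CsvLikeOptions.getAlternativeOption(CsvLikeOptions.ENCODING)
             if parameters.contains(CsvLikeOptions.ENCODING) && parameters.contains(alt)) {
          println(s"Both 'encoding' and '$alt' are set; 'encoding' takes precedence")
        }
        parameters.getOrElse(CsvLikeOptions.ENCODING,
          parameters.getOrElse(CsvLikeOptions.CHARSET, StandardCharsets.UTF_8.name()))
      }
    }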