diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index e065bbce27094..95001bb81508c 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} import org.apache.spark.sql.internal.SQLConf @@ -37,6 +37,8 @@ private[sql] class AvroOptions( @transient val conf: Configuration) extends FileSourceOptions(parameters) with Logging { + import AvroOptions._ + def this(parameters: Map[String, String], conf: Configuration) = { this(CaseInsensitiveMap(parameters), conf) } @@ -54,8 +56,8 @@ private[sql] class AvroOptions( * instead of "string" type in the default converted schema. */ val schema: Option[Schema] = { - parameters.get("avroSchema").map(new Schema.Parser().setValidateDefaults(false).parse).orElse({ - val avroUrlSchema = parameters.get("avroSchemaUrl").map(url => { + parameters.get(AVRO_SCHEMA).map(new Schema.Parser().setValidateDefaults(false).parse).orElse({ + val avroUrlSchema = parameters.get(AVRO_SCHEMA_URL).map(url => { log.debug("loading avro schema from url: " + url) val fs = FileSystem.get(new URI(url), conf) val in = fs.open(new Path(url)) @@ -75,20 +77,20 @@ private[sql] class AvroOptions( * whose field names do not match. Defaults to false. */ val positionalFieldMatching: Boolean = - parameters.get("positionalFieldMatching").exists(_.toBoolean) + parameters.get(POSITIONAL_FIELD_MATCHING).exists(_.toBoolean) /** * Top level record name in write result, which is required in Avro spec. * See https://avro.apache.org/docs/1.11.1/specification/#schema-record . * Default value is "topLevelRecord" */ - val recordName: String = parameters.getOrElse("recordName", "topLevelRecord") + val recordName: String = parameters.getOrElse(RECORD_NAME, "topLevelRecord") /** * Record namespace in write result. Default value is "". * See Avro spec for details: https://avro.apache.org/docs/1.11.1/specification/#schema-record . */ - val recordNamespace: String = parameters.getOrElse("recordNamespace", "") + val recordNamespace: String = parameters.getOrElse(RECORD_NAMESPACE, "") /** * The `ignoreExtension` option controls ignoring of files without `.avro` extensions in read. @@ -104,7 +106,7 @@ private[sql] class AvroOptions( ignoreFilesWithoutExtensionByDefault) parameters - .get(AvroOptions.ignoreExtensionKey) + .get(IGNORE_EXTENSION) .map(_.toBoolean) .getOrElse(!ignoreFilesWithoutExtension) } @@ -116,21 +118,21 @@ private[sql] class AvroOptions( * taken into account. If the former one is not set too, the `snappy` codec is used by default. */ val compression: String = { - parameters.get("compression").getOrElse(SQLConf.get.avroCompressionCodec) + parameters.get(COMPRESSION).getOrElse(SQLConf.get.avroCompressionCodec) } val parseMode: ParseMode = - parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) + parameters.get(MODE).map(ParseMode.fromString).getOrElse(FailFastMode) /** * The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads. 
*/ val datetimeRebaseModeInRead: String = parameters - .get(AvroOptions.DATETIME_REBASE_MODE) + .get(DATETIME_REBASE_MODE) .getOrElse(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ)) } -private[sql] object AvroOptions { +private[sql] object AvroOptions extends DataSourceOptions { def apply(parameters: Map[String, String]): AvroOptions = { val hadoopConf = SparkSession .getActiveSession @@ -139,11 +141,17 @@ private[sql] object AvroOptions { new AvroOptions(CaseInsensitiveMap(parameters), hadoopConf) } - val ignoreExtensionKey = "ignoreExtension" - + val IGNORE_EXTENSION = newOption("ignoreExtension") + val MODE = newOption("mode") + val RECORD_NAME = newOption("recordName") + val COMPRESSION = newOption("compression") + val AVRO_SCHEMA = newOption("avroSchema") + val AVRO_SCHEMA_URL = newOption("avroSchemaUrl") + val RECORD_NAMESPACE = newOption("recordNamespace") + val POSITIONAL_FIELD_MATCHING = newOption("positionalFieldMatching") // The option controls rebasing of the DATE and TIMESTAMP values between // Julian and Proleptic Gregorian calendars. It impacts on the behaviour of the Avro // datasource similarly to the SQL config `spark.sql.avro.datetimeRebaseModeInRead`, // and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`. - val DATETIME_REBASE_MODE = "datetimeRebaseMode" + val DATETIME_REBASE_MODE = newOption("datetimeRebaseMode") } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 56d177da14369..45fa7450e4522 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.avro.AvroOptions.ignoreExtensionKey +import org.apache.spark.sql.avro.AvroOptions.IGNORE_EXTENSION import org.apache.spark.sql.catalyst.{FileSourceOptions, InternalRow} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.OutputWriterFactory @@ -50,8 +50,8 @@ private[sql] object AvroUtils extends Logging { val conf = spark.sessionState.newHadoopConfWithOptions(options) val parsedOptions = new AvroOptions(options, conf) - if (parsedOptions.parameters.contains(ignoreExtensionKey)) { - logWarning(s"Option $ignoreExtensionKey is deprecated. Please use the " + + if (parsedOptions.parameters.contains(IGNORE_EXTENSION)) { + logWarning(s"Option $IGNORE_EXTENSION is deprecated. Please use the " + "general data source option pathGlobFilter for filtering file names.") } // User can specify an optional avro json schema. 
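A minimal usage sketch (not part of the patch) of the renamed constant, following the call pattern exercised by AvroSuite below; it assumes a SparkSession `spark`, a directory `dir` of Avro files, and code living under the org.apache.spark.sql package, since AvroOptions is private[sql]:

```scala
// Read Avro data using the registered constant instead of the raw
// "ignoreExtension" string literal; the deprecation warning in AvroUtils
// now interpolates the same constant.
val count = spark.read
  .format("avro")
  .option(AvroOptions.IGNORE_EXTENSION, false)
  .load(dir.getCanonicalPath)
  .count()

// The option name is also discoverable through the new registry.
assert(AvroOptions.isValidOption("ignoreExtension"))
```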
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index cf4a490b90273..a02bb067dcc4b 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1804,13 +1804,13 @@ abstract class AvroSuite spark .read .format("avro") - .option(AvroOptions.ignoreExtensionKey, false) + .option(AvroOptions.IGNORE_EXTENSION, false) .load(dir.getCanonicalPath) .count() } val deprecatedEvents = logAppender.loggingEvents .filter(_.getMessage.getFormattedMessage.contains( - s"Option ${AvroOptions.ignoreExtensionKey} is deprecated")) + s"Option ${AvroOptions.IGNORE_EXTENSION} is deprecated")) assert(deprecatedEvents.size === 1) } } @@ -2272,6 +2272,20 @@ abstract class AvroSuite checkAnswer(df2, df.collect().toSeq) } } + + test("SPARK-40667: validate Avro Options") { + assert(AvroOptions.getAllOptions.size == 9) + // Please add validation on any new Avro options here + assert(AvroOptions.isValidOption("ignoreExtension")) + assert(AvroOptions.isValidOption("mode")) + assert(AvroOptions.isValidOption("recordName")) + assert(AvroOptions.isValidOption("compression")) + assert(AvroOptions.isValidOption("avroSchema")) + assert(AvroOptions.isValidOption("avroSchemaUrl")) + assert(AvroOptions.isValidOption("recordNamespace")) + assert(AvroOptions.isValidOption("positionalFieldMatching")) + assert(AvroOptions.isValidOption("datetimeRebaseMode")) + } } class AvroV1Suite extends AvroSuite { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala new file mode 100644 index 0000000000000..5348d1054d5d4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DataSourceOptions.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +/** + * Interface defines the following methods for a data source: + * - register a new option name + * - retrieve all registered option names + * - validate a given option name + * - get the alternative option name, if any + */ +trait DataSourceOptions { + // Option -> Alternative Option if any + private val validOptions = collection.mutable.Map[String, Option[String]]() + + /** + * Register a new Option. + */ + protected def newOption(name: String): String = { + validOptions += (name -> None) + name + } + + /** + * Register a new Option with an alternative name.
+ * @param name Option name + * @param alternative Alternative option name + */ + protected def newOption(name: String, alternative: String): Unit = { + // Register both of the options + validOptions += (name -> Some(alternative)) + validOptions += (alternative -> Some(name)) + } + + /** + * @return All data source options and their alternatives if any + */ + def getAllOptions: scala.collection.Set[String] = validOptions.keySet + + /** + * @param name Option name to be validated + * @return if the given Option name is valid + */ + def isValidOption(name: String): Boolean = validOptions.contains(name) + + /** + * @param name Option name + * @return Alternative option name if any + */ + def getAlternativeOption(name: String): Option[String] = validOptions.get(name).flatten +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 88396c65cc070..a66070aa853d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -24,7 +24,7 @@ import java.util.Locale import com.univocity.parsers.csv.{CsvParserSettings, CsvWriterSettings, UnescapedQuoteHandling} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -37,6 +37,8 @@ class CSVOptions( defaultColumnNameOfCorruptRecord: String) extends FileSourceOptions(parameters) with Logging { + import CSVOptions._ + def this( parameters: Map[String, String], columnPruning: Boolean, @@ -99,46 +101,46 @@ class CSVOptions( } val delimiter = CSVExprUtils.toDelimiterStr( - parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) + parameters.getOrElse(SEP, parameters.getOrElse(DELIMITER, ","))) val parseMode: ParseMode = - parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) - val charset = parameters.getOrElse("encoding", - parameters.getOrElse("charset", StandardCharsets.UTF_8.name())) + parameters.get(MODE).map(ParseMode.fromString).getOrElse(PermissiveMode) + val charset = parameters.getOrElse(ENCODING, + parameters.getOrElse(CHARSET, StandardCharsets.UTF_8.name())) - val quote = getChar("quote", '\"') - val escape = getChar("escape", '\\') - val charToEscapeQuoteEscaping = parameters.get("charToEscapeQuoteEscaping") match { + val quote = getChar(QUOTE, '\"') + val escape = getChar(ESCAPE, '\\') + val charToEscapeQuoteEscaping = parameters.get(CHAR_TO_ESCAPE_QUOTE_ESCAPING) match { case None => None case Some(null) => None case Some(value) if value.length == 0 => None case Some(value) if value.length == 1 => Some(value.charAt(0)) - case _ => throw QueryExecutionErrors.paramExceedOneCharError("charToEscapeQuoteEscaping") + case _ => throw QueryExecutionErrors.paramExceedOneCharError(CHAR_TO_ESCAPE_QUOTE_ESCAPING) } - val comment = getChar("comment", '\u0000') + val comment = getChar(COMMENT, '\u0000') - val headerFlag = getBool("header") - val inferSchemaFlag = getBool("inferSchema") - val ignoreLeadingWhiteSpaceInRead = getBool("ignoreLeadingWhiteSpace", default = false) - val ignoreTrailingWhiteSpaceInRead = getBool("ignoreTrailingWhiteSpace", default = false) + val headerFlag = getBool(HEADER) + val inferSchemaFlag = 
getBool(INFER_SCHEMA) + val ignoreLeadingWhiteSpaceInRead = getBool(IGNORE_LEADING_WHITESPACE, default = false) + val ignoreTrailingWhiteSpaceInRead = getBool(IGNORE_TRAILING_WHITESPACE, default = false) // For write, both options were `true` by default. We leave it as `true` for // backwards compatibility. - val ignoreLeadingWhiteSpaceFlagInWrite = getBool("ignoreLeadingWhiteSpace", default = true) - val ignoreTrailingWhiteSpaceFlagInWrite = getBool("ignoreTrailingWhiteSpace", default = true) + val ignoreLeadingWhiteSpaceFlagInWrite = getBool(IGNORE_LEADING_WHITESPACE, default = true) + val ignoreTrailingWhiteSpaceFlagInWrite = getBool(IGNORE_TRAILING_WHITESPACE, default = true) val columnNameOfCorruptRecord = - parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + parameters.getOrElse(COLUMN_NAME_OF_CORRUPT_RECORD, defaultColumnNameOfCorruptRecord) - val nullValue = parameters.getOrElse("nullValue", "") + val nullValue = parameters.getOrElse(NULL_VALUE, "") - val nanValue = parameters.getOrElse("nanValue", "NaN") + val nanValue = parameters.getOrElse(NAN_VALUE, "NaN") - val positiveInf = parameters.getOrElse("positiveInf", "Inf") - val negativeInf = parameters.getOrElse("negativeInf", "-Inf") + val positiveInf = parameters.getOrElse(POSITIVE_INF, "Inf") + val negativeInf = parameters.getOrElse(NEGATIVE_INF, "-Inf") val compressionCodec: Option[String] = { - val name = parameters.get("compression").orElse(parameters.get("codec")) + val name = parameters.get(COMPRESSION).orElse(parameters.get(CODEC)) name.map(CompressionCodecs.getCodecClassName) } @@ -146,7 +148,7 @@ class CSVOptions( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) // A language tag in IETF BCP 47 format - val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + val locale: Locale = parameters.get(LOCALE).map(Locale.forLanguageTag).getOrElse(Locale.US) /** * Infer columns with all valid date entries as date type (otherwise inferred as string or @@ -161,11 +163,11 @@ class CSVOptions( if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { false } else { - getBool("prefersDate", true) + getBool(PREFERS_DATE, true) } } - val dateFormatOption: Option[String] = parameters.get("dateFormat") + val dateFormatOption: Option[String] = parameters.get(DATE_FORMAT) // Provide a default value for dateFormatInRead when prefersDate. 
This ensures that the // Iso8601DateFormatter (with strict date parsing) is used for date inference val dateFormatInRead: Option[String] = @@ -174,24 +176,24 @@ class CSVOptions( } else { dateFormatOption } - val dateFormatInWrite: String = parameters.getOrElse("dateFormat", DateFormatter.defaultPattern) + val dateFormatInWrite: String = parameters.getOrElse(DATE_FORMAT, DateFormatter.defaultPattern) val timestampFormatInRead: Option[String] = if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { - Some(parameters.getOrElse("timestampFormat", + Some(parameters.getOrElse(TIMESTAMP_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX")) } else { - parameters.get("timestampFormat") + parameters.get(TIMESTAMP_FORMAT) } - val timestampFormatInWrite: String = parameters.getOrElse("timestampFormat", + val timestampFormatInWrite: String = parameters.getOrElse(TIMESTAMP_FORMAT, if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX" } else { s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]" }) - val timestampNTZFormatInRead: Option[String] = parameters.get("timestampNTZFormat") - val timestampNTZFormatInWrite: String = parameters.getOrElse("timestampNTZFormat", + val timestampNTZFormatInRead: Option[String] = parameters.get(TIMESTAMP_NTZ_FORMAT) + val timestampNTZFormatInWrite: String = parameters.getOrElse(TIMESTAMP_NTZ_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") // SPARK-39731: Enables the backward compatible parsing behavior. @@ -203,17 +205,17 @@ class CSVOptions( // Otherwise, depending on the parser policy and a custom pattern, an exception may be thrown and // the value will be parsed as null. val enableDateTimeParsingFallback: Option[Boolean] = - parameters.get("enableDateTimeParsingFallback").map(_.toBoolean) + parameters.get(ENABLE_DATETIME_PARSING_FALLBACK).map(_.toBoolean) - val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get(MULTI_LINE).map(_.toBoolean).getOrElse(false) - val maxColumns = getInt("maxColumns", 20480) + val maxColumns = getInt(MAX_COLUMNS, 20480) - val maxCharsPerColumn = getInt("maxCharsPerColumn", -1) + val maxCharsPerColumn = getInt(MAX_CHARS_PER_COLUMN, -1) - val escapeQuotes = getBool("escapeQuotes", true) + val escapeQuotes = getBool(ESCAPE_QUOTES, true) - val quoteAll = getBool("quoteAll", false) + val quoteAll = getBool(QUOTE_ALL, false) /** * The max error content length in CSV parser/writer exception message. @@ -223,18 +225,18 @@ class CSVOptions( val isCommentSet = this.comment != '\u0000' val samplingRatio = - parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + parameters.get(SAMPLING_RATIO).map(_.toDouble).getOrElse(1.0) /** * Forcibly apply the specified or inferred schema to datasource files. * If the option is enabled, headers of CSV files will be ignored. */ - val enforceSchema = getBool("enforceSchema", default = true) + val enforceSchema = getBool(ENFORCE_SCHEMA, default = true) /** * String representation of an empty value in read and in write. */ - val emptyValue = parameters.get("emptyValue") + val emptyValue = parameters.get(EMPTY_VALUE) /** * The string is returned when CSV reader doesn't have any characters for input value, * or an empty quoted string `""`. Default value is empty string. @@ -248,7 +250,7 @@ class CSVOptions( /** * A string between two consecutive JSON records. 
*/ - val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + val lineSeparator: Option[String] = parameters.get(LINE_SEP).map { sep => require(sep.nonEmpty, "'lineSep' cannot be an empty string.") // Intentionally allow it up to 2 for Window's CRLF although multiple // characters have an issue with quotes. This is intentionally undocumented. @@ -263,14 +265,14 @@ class CSVOptions( } val lineSeparatorInWrite: Option[String] = lineSeparator - val inputBufferSize: Option[Int] = parameters.get("inputBufferSize").map(_.toInt) + val inputBufferSize: Option[Int] = parameters.get(INPUT_BUFFER_SIZE).map(_.toInt) .orElse(SQLConf.get.getConf(SQLConf.CSV_INPUT_BUFFER_SIZE)) /** * The handling method to be used when unescaped quotes are found in the input. */ val unescapedQuoteHandling: UnescapedQuoteHandling = UnescapedQuoteHandling.valueOf(parameters - .getOrElse("unescapedQuoteHandling", "STOP_AT_DELIMITER").toUpperCase(Locale.ROOT)) + .getOrElse(UNESCAPED_QUOTE_HANDLING, "STOP_AT_DELIMITER").toUpperCase(Locale.ROOT)) def asWriterSettings: CsvWriterSettings = { val writerSettings = new CsvWriterSettings() @@ -327,3 +329,48 @@ class CSVOptions( settings } } + +object CSVOptions extends DataSourceOptions { + val HEADER = newOption("header") + val INFER_SCHEMA = newOption("inferSchema") + val IGNORE_LEADING_WHITESPACE = newOption("ignoreLeadingWhiteSpace") + val IGNORE_TRAILING_WHITESPACE = newOption("ignoreTrailingWhiteSpace") + val PREFERS_DATE = newOption("prefersDate") + val ESCAPE_QUOTES = newOption("escapeQuotes") + val QUOTE_ALL = newOption("quoteAll") + val ENFORCE_SCHEMA = newOption("enforceSchema") + val QUOTE = newOption("quote") + val ESCAPE = newOption("escape") + val COMMENT = newOption("comment") + val MAX_COLUMNS = newOption("maxColumns") + val MAX_CHARS_PER_COLUMN = newOption("maxCharsPerColumn") + val MODE = newOption("mode") + val CHAR_TO_ESCAPE_QUOTE_ESCAPING = newOption("charToEscapeQuoteEscaping") + val LOCALE = newOption("locale") + val DATE_FORMAT = newOption("dateFormat") + val TIMESTAMP_FORMAT = newOption("timestampFormat") + val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat") + val ENABLE_DATETIME_PARSING_FALLBACK = newOption("enableDateTimeParsingFallback") + val MULTI_LINE = newOption("multiLine") + val SAMPLING_RATIO = newOption("samplingRatio") + val EMPTY_VALUE = newOption("emptyValue") + val LINE_SEP = newOption("lineSep") + val INPUT_BUFFER_SIZE = newOption("inputBufferSize") + val COLUMN_NAME_OF_CORRUPT_RECORD = newOption("columnNameOfCorruptRecord") + val NULL_VALUE = newOption("nullValue") + val NAN_VALUE = newOption("nanValue") + val POSITIVE_INF = newOption("positiveInf") + val NEGATIVE_INF = newOption("negativeInf") + val TIME_ZONE = newOption("timeZone") + val UNESCAPED_QUOTE_HANDLING = newOption("unescapedQuoteHandling") + // Options with alternative + val ENCODING = "encoding" + val CHARSET = "charset" + newOption(ENCODING, CHARSET) + val COMPRESSION = "compression" + val CODEC = "codec" + newOption(COMPRESSION, CODEC) + val SEP = "sep" + val DELIMITER = "delimiter" + newOption(SEP, DELIMITER) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 9679a60622bc9..bf5b83e9df0f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -25,7 +25,7 @@ import 
com.fasterxml.jackson.core.{JsonFactory, JsonFactoryBuilder} import com.fasterxml.jackson.core.json.JsonReadFeature import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy @@ -41,6 +41,8 @@ private[sql] class JSONOptions( defaultColumnNameOfCorruptRecord: String) extends FileSourceOptions(parameters) with Logging { + import JSONOptions._ + def this( parameters: Map[String, String], defaultTimeZoneId: String, @@ -52,36 +54,36 @@ private[sql] class JSONOptions( } val samplingRatio = - parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) + parameters.get(SAMPLING_RATIO).map(_.toDouble).getOrElse(1.0) val primitivesAsString = - parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) + parameters.get(PRIMITIVES_AS_STRING).map(_.toBoolean).getOrElse(false) val prefersDecimal = - parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false) + parameters.get(PREFERS_DECIMAL).map(_.toBoolean).getOrElse(false) val allowComments = - parameters.get("allowComments").map(_.toBoolean).getOrElse(false) + parameters.get(ALLOW_COMMENTS).map(_.toBoolean).getOrElse(false) val allowUnquotedFieldNames = - parameters.get("allowUnquotedFieldNames").map(_.toBoolean).getOrElse(false) + parameters.get(ALLOW_UNQUOTED_FIELD_NAMES).map(_.toBoolean).getOrElse(false) val allowSingleQuotes = - parameters.get("allowSingleQuotes").map(_.toBoolean).getOrElse(true) + parameters.get(ALLOW_SINGLE_QUOTES).map(_.toBoolean).getOrElse(true) val allowNumericLeadingZeros = - parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false) + parameters.get(ALLOW_NUMERIC_LEADING_ZEROS).map(_.toBoolean).getOrElse(false) val allowNonNumericNumbers = - parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + parameters.get(ALLOW_NON_NUMERIC_NUMBERS).map(_.toBoolean).getOrElse(true) val allowBackslashEscapingAnyCharacter = - parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) + parameters.get(ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER).map(_.toBoolean).getOrElse(false) private val allowUnquotedControlChars = - parameters.get("allowUnquotedControlChars").map(_.toBoolean).getOrElse(false) - val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) + parameters.get(ALLOW_UNQUOTED_CONTROL_CHARS).map(_.toBoolean).getOrElse(false) + val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName) val parseMode: ParseMode = - parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode) + parameters.get(MODE).map(ParseMode.fromString).getOrElse(PermissiveMode) val columnNameOfCorruptRecord = - parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) + parameters.getOrElse(COLUMN_NAME_OF_CORRUPTED_RECORD, defaultColumnNameOfCorruptRecord) // Whether to ignore column of all null values or empty array/struct during schema inference - val dropFieldIfAllNull = parameters.get("dropFieldIfAllNull").map(_.toBoolean).getOrElse(false) + val dropFieldIfAllNull = parameters.get(DROP_FIELD_IF_ALL_NULL).map(_.toBoolean).getOrElse(false) // Whether to ignore null fields during json generating - val ignoreNullFields = parameters.get("ignoreNullFields").map(_.toBoolean) + val ignoreNullFields = 
parameters.get(IGNORE_NULL_FIELDS).map(_.toBoolean) .getOrElse(SQLConf.get.jsonGeneratorIgnoreNullFields) // If this is true, when writing NULL values to columns of JSON tables with explicit DEFAULT @@ -91,31 +93,31 @@ private[sql] class JSONOptions( val writeNullIfWithDefaultValue = SQLConf.get.jsonWriteNullIfWithDefaultValue // A language tag in IETF BCP 47 format - val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) + val locale: Locale = parameters.get(LOCALE).map(Locale.forLanguageTag).getOrElse(Locale.US) val zoneId: ZoneId = DateTimeUtils.getZoneId( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) - val dateFormatInRead: Option[String] = parameters.get("dateFormat") - val dateFormatInWrite: String = parameters.getOrElse("dateFormat", DateFormatter.defaultPattern) + val dateFormatInRead: Option[String] = parameters.get(DATE_FORMAT) + val dateFormatInWrite: String = parameters.getOrElse(DATE_FORMAT, DateFormatter.defaultPattern) val timestampFormatInRead: Option[String] = if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { - Some(parameters.getOrElse("timestampFormat", + Some(parameters.getOrElse(TIMESTAMP_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX")) } else { - parameters.get("timestampFormat") + parameters.get(TIMESTAMP_FORMAT) } - val timestampFormatInWrite: String = parameters.getOrElse("timestampFormat", + val timestampFormatInWrite: String = parameters.getOrElse(TIMESTAMP_FORMAT, if (SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY) { s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX" } else { s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]" }) - val timestampNTZFormatInRead: Option[String] = parameters.get("timestampNTZFormat") + val timestampNTZFormatInRead: Option[String] = parameters.get(TIMESTAMP_NTZ_FORMAT) val timestampNTZFormatInWrite: String = - parameters.getOrElse("timestampNTZFormat", s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") + parameters.getOrElse(TIMESTAMP_NTZ_FORMAT, s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") // SPARK-39731: Enables the backward compatible parsing behavior. // Generally, this config should be set to false to avoid producing potentially incorrect results @@ -126,14 +128,14 @@ private[sql] class JSONOptions( // Otherwise, depending on the parser policy and a custom pattern, an exception may be thrown and // the value will be parsed as null. val enableDateTimeParsingFallback: Option[Boolean] = - parameters.get("enableDateTimeParsingFallback").map(_.toBoolean) + parameters.get(ENABLE_DATETIME_PARSING_FALLBACK).map(_.toBoolean) - val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) + val multiLine = parameters.get(MULTI_LINE).map(_.toBoolean).getOrElse(false) /** * A string between two consecutive JSON records. */ - val lineSeparator: Option[String] = parameters.get("lineSep").map { sep => + val lineSeparator: Option[String] = parameters.get(LINE_SEP).map { sep => require(sep.nonEmpty, "'lineSep' cannot be an empty string.") sep } @@ -146,8 +148,8 @@ private[sql] class JSONOptions( * when the multiLine option is set to `true`. If encoding is not specified in write, * UTF-8 is used by default. 
*/ - val encoding: Option[String] = parameters.get("encoding") - .orElse(parameters.get("charset")).map(checkedEncoding) + val encoding: Option[String] = parameters.get(ENCODING) + .orElse(parameters.get(CHARSET)).map(checkedEncoding) val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep => lineSep.getBytes(encoding.getOrElse(StandardCharsets.UTF_8.name())) @@ -157,20 +159,20 @@ private[sql] class JSONOptions( /** * Generating JSON strings in pretty representation if the parameter is enabled. */ - val pretty: Boolean = parameters.get("pretty").map(_.toBoolean).getOrElse(false) + val pretty: Boolean = parameters.get(PRETTY).map(_.toBoolean).getOrElse(false) /** * Enables inferring of TimestampType and TimestampNTZType from strings matched to the * corresponding timestamp pattern defined by the timestampFormat and timestampNTZFormat options * respectively. */ - val inferTimestamp: Boolean = parameters.get("inferTimestamp").map(_.toBoolean).getOrElse(false) + val inferTimestamp: Boolean = parameters.get(INFER_TIMESTAMP).map(_.toBoolean).getOrElse(false) /** * Generating \u0000 style codepoints for non-ASCII characters if the parameter is enabled. */ val writeNonAsciiCharacterAsCodePoint: Boolean = - parameters.get("writeNonAsciiCharacterAsCodePoint").map(_.toBoolean).getOrElse(false) + parameters.get(WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT).map(_.toBoolean).getOrElse(false) /** Build a Jackson [[JsonFactory]] using JSON options. */ def buildJsonFactory(): JsonFactory = { @@ -230,3 +232,36 @@ private[sql] object JSONOptionsInRead { Charset.forName("UTF-32") ) } + +object JSONOptions extends DataSourceOptions { + val SAMPLING_RATIO = newOption("samplingRatio") + val PRIMITIVES_AS_STRING = newOption("primitivesAsString") + val PREFERS_DECIMAL = newOption("prefersDecimal") + val ALLOW_COMMENTS = newOption("allowComments") + val ALLOW_UNQUOTED_FIELD_NAMES = newOption("allowUnquotedFieldNames") + val ALLOW_SINGLE_QUOTES = newOption("allowSingleQuotes") + val ALLOW_NUMERIC_LEADING_ZEROS = newOption("allowNumericLeadingZeros") + val ALLOW_NON_NUMERIC_NUMBERS = newOption("allowNonNumericNumbers") + val ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER = newOption("allowBackslashEscapingAnyCharacter") + val ALLOW_UNQUOTED_CONTROL_CHARS = newOption("allowUnquotedControlChars") + val COMPRESSION = newOption("compression") + val MODE = newOption("mode") + val DROP_FIELD_IF_ALL_NULL = newOption("dropFieldIfAllNull") + val IGNORE_NULL_FIELDS = newOption("ignoreNullFields") + val LOCALE = newOption("locale") + val DATE_FORMAT = newOption("dateFormat") + val TIMESTAMP_FORMAT = newOption("timestampFormat") + val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat") + val ENABLE_DATETIME_PARSING_FALLBACK = newOption("enableDateTimeParsingFallback") + val MULTI_LINE = newOption("multiLine") + val LINE_SEP = newOption("lineSep") + val PRETTY = newOption("pretty") + val INFER_TIMESTAMP = newOption("inferTimestamp") + val COLUMN_NAME_OF_CORRUPTED_RECORD = newOption("columnNameOfCorruptRecord") + val TIME_ZONE = newOption("timeZone") + val WRITE_NON_ASCII_CHARACTER_AS_CODEPOINT = newOption("writeNonAsciiCharacterAsCodePoint") + // Options with alternative + val ENCODING = "encoding" + val CHARSET = "charset" + newOption(ENCODING, CHARSET) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndexOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndexOptions.scala new file mode 100644 index 0000000000000..1c352e3748f21 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileIndexOptions.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +object FileIndexOptions extends DataSourceOptions { + val IGNORE_MISSING_FILES = newOption(FileSourceOptions.IGNORE_MISSING_FILES) + val TIME_ZONE = newOption(DateTimeUtils.TIMEZONE_OPTION) + val RECURSIVE_FILE_LOOKUP = newOption("recursiveFileLookup") + val BASE_PATH_PARAM = newOption("basePath") + val MODIFIED_BEFORE = newOption("modifiedbefore") + val MODIFIED_AFTER = newOption("modifiedafter") + val PATH_GLOB_FILTER = newOption("pathglobfilter") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index d70c4b11bc0d7..53be85ad44844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.FileFormat.createMetadataInternalRow import org.apache.spark.sql.types.StructType @@ -43,7 +43,6 @@ abstract class PartitioningAwareFileIndex( parameters: Map[String, String], userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { - import PartitioningAwareFileIndex.BASE_PATH_PARAM /** Returns the specification of the partitions inferred from the data. 
*/ def partitionSpec(): PartitionSpec @@ -64,7 +63,7 @@ abstract class PartitioningAwareFileIndex( pathFilters.forall(_.accept(file)) protected lazy val recursiveFileLookup: Boolean = { - caseInsensitiveMap.getOrElse("recursiveFileLookup", "false").toBoolean + caseInsensitiveMap.getOrElse(FileIndexOptions.RECURSIVE_FILE_LOOKUP, "false").toBoolean } override def listFiles( @@ -178,7 +177,7 @@ abstract class PartitioningAwareFileIndex( }.keys.toSeq val caseInsensitiveOptions = CaseInsensitiveMap(parameters) - val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) + val timeZoneId = caseInsensitiveOptions.get(FileIndexOptions.TIME_ZONE) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) PartitioningUtils.parsePartitions( @@ -248,11 +247,12 @@ abstract class PartitioningAwareFileIndex( * and the returned DataFrame will have the column of `something`. */ private def basePaths: Set[Path] = { - caseInsensitiveMap.get(BASE_PATH_PARAM).map(new Path(_)) match { + caseInsensitiveMap.get(FileIndexOptions.BASE_PATH_PARAM).map(new Path(_)) match { case Some(userDefinedBasePath) => val fs = userDefinedBasePath.getFileSystem(hadoopConf) if (!fs.isDirectory(userDefinedBasePath)) { - throw new IllegalArgumentException(s"Option '$BASE_PATH_PARAM' must be a directory") + throw new IllegalArgumentException(s"Option '${FileIndexOptions.BASE_PATH_PARAM}' " + + s"must be a directory") } val qualifiedBasePath = fs.makeQualified(userDefinedBasePath) val qualifiedBasePathStr = qualifiedBasePath.toString @@ -279,7 +279,3 @@ abstract class PartitioningAwareFileIndex( !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } } - -object PartitioningAwareFileIndex { - val BASE_PATH_PARAM = "basePath" -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index ef1c2bb5b4104..1c819f07038ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.orc.OrcConf.COMPRESS -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -45,9 +45,9 @@ class OrcOptions( val compressionCodec: String = { // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec` // are in order of precedence from highest to lowest. 
- val orcCompressionConf = parameters.get(COMPRESS.getAttribute) + val orcCompressionConf = parameters.get(ORC_COMPRESSION) val codecName = parameters - .get("compression") + .get(COMPRESSION) .orElse(orcCompressionConf) .getOrElse(sqlConf.orcCompressionCodec) .toLowerCase(Locale.ROOT) @@ -69,8 +69,10 @@ class OrcOptions( .getOrElse(sqlConf.isOrcSchemaMergingEnabled) } -object OrcOptions { - val MERGE_SCHEMA = "mergeSchema" +object OrcOptions extends DataSourceOptions { + val MERGE_SCHEMA = newOption("mergeSchema") + val ORC_COMPRESSION = newOption(COMPRESS.getAttribute) + val COMPRESSION = newOption("compression") // The ORC compression short names private val shortOrcCompressionCodecNames = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 07ed55b0b8f84..d20edbde00be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -47,9 +47,9 @@ class ParquetOptions( // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and // `spark.sql.parquet.compression.codec` // are in order of precedence from highest to lowest. - val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val parquetCompressionConf = parameters.get(PARQUET_COMPRESSION) val codecName = parameters - .get("compression") + .get(COMPRESSION) .orElse(parquetCompressionConf) .getOrElse(sqlConf.parquetCompressionCodec) .toLowerCase(Locale.ROOT) @@ -86,9 +86,7 @@ class ParquetOptions( } -object ParquetOptions { - val MERGE_SCHEMA = "mergeSchema" - +object ParquetOptions extends DataSourceOptions { // The parquet compression short names private val shortParquetCompressionCodecNames = Map( "none" -> CompressionCodecName.UNCOMPRESSED, @@ -104,15 +102,19 @@ object ParquetOptions { shortParquetCompressionCodecNames(name).name() } + val MERGE_SCHEMA = newOption("mergeSchema") + val PARQUET_COMPRESSION = newOption(ParquetOutputFormat.COMPRESSION) + val COMPRESSION = newOption("compression") + // The option controls rebasing of the DATE and TIMESTAMP values between // Julian and Proleptic Gregorian calendars. It impacts on the behaviour of the Parquet // datasource similarly to the SQL config `spark.sql.parquet.datetimeRebaseModeInRead`, // and can be set to the same values: `EXCEPTION`, `LEGACY` or `CORRECTED`. - val DATETIME_REBASE_MODE = "datetimeRebaseMode" + val DATETIME_REBASE_MODE = newOption("datetimeRebaseMode") // The option controls rebasing of the INT96 timestamp values between Julian and Proleptic // Gregorian calendars. It impacts on the behaviour of the Parquet datasource similarly to // the SQL config `spark.sql.parquet.int96RebaseModeInRead`. // The valid option values are: `EXCEPTION`, `LEGACY` or `CORRECTED`. 
- val INT96_REBASE_MODE = "int96RebaseMode" + val INT96_REBASE_MODE = newOption("int96RebaseMode") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala index d07e1957e8c6f..303129b4d576f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/pathFilters.scala @@ -43,10 +43,8 @@ class PathGlobFilter(filePatten: String) extends PathFilterStrategy { } object PathGlobFilter extends StrategyBuilder { - val PARAM_NAME = "pathglobfilter" - override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { - parameters.get(PARAM_NAME).map(new PathGlobFilter(_)) + parameters.get(FileIndexOptions.PATH_GLOB_FILTER).map(new PathGlobFilter(_)) } } @@ -111,12 +109,10 @@ class ModifiedBeforeFilter(thresholdTime: Long, val timeZoneId: String) object ModifiedBeforeFilter extends StrategyBuilder { import ModifiedDateFilter._ - val PARAM_NAME = "modifiedbefore" - override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { - parameters.get(PARAM_NAME).map { value => + parameters.get(FileIndexOptions.MODIFIED_BEFORE).map { value => val timeZoneId = getTimeZoneId(parameters) - val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME) + val thresholdTime = toThreshold(value, timeZoneId, FileIndexOptions.MODIFIED_BEFORE) new ModifiedBeforeFilter(thresholdTime, timeZoneId) } } @@ -137,12 +133,10 @@ class ModifiedAfterFilter(thresholdTime: Long, val timeZoneId: String) object ModifiedAfterFilter extends StrategyBuilder { import ModifiedDateFilter._ - val PARAM_NAME = "modifiedafter" - override def create(parameters: CaseInsensitiveMap[String]): Option[PathFilterStrategy] = { - parameters.get(PARAM_NAME).map { value => + parameters.get(FileIndexOptions.MODIFIED_AFTER).map { value => val timeZoneId = getTimeZoneId(parameters) - val thresholdTime = toThreshold(value, timeZoneId, PARAM_NAME) + val thresholdTime = toThreshold(value, timeZoneId, FileIndexOptions.MODIFIED_AFTER) new ModifiedAfterFilter(thresholdTime, timeZoneId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala index f1a1d465d1b8c..f26f05cbe1c55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOptions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.text import java.nio.charset.{Charset, StandardCharsets} -import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs} /** @@ -44,8 +44,8 @@ class TextOptions(@transient private val parameters: CaseInsensitiveMap[String]) val encoding: Option[String] = parameters.get(ENCODING) - val lineSeparator: Option[String] = parameters.get(LINE_SEPARATOR).map { lineSep => - require(lineSep.nonEmpty, s"'$LINE_SEPARATOR' cannot be an empty string.") + val lineSeparator: Option[String] = parameters.get(LINE_SEP).map { lineSep => + require(lineSep.nonEmpty, s"'$LINE_SEP' cannot be an empty string.") lineSep } @@ -58,9 +58,9 @@ class TextOptions(@transient private 
val parameters: CaseInsensitiveMap[String]) lineSeparatorInRead.getOrElse("\n".getBytes(StandardCharsets.UTF_8)) } -private[datasources] object TextOptions { - val COMPRESSION = "compression" - val WHOLETEXT = "wholetext" - val ENCODING = "encoding" - val LINE_SEPARATOR = "lineSep" +private[sql] object TextOptions extends DataSourceOptions { + val COMPRESSION = newOption("compression") + val WHOLETEXT = newOption("wholetext") + val ENCODING = newOption("encoding") + val LINE_SEP = newOption("lineSep") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala index a5c1c735cbd7b..ae09095590865 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -23,7 +23,7 @@ import scala.util.Try import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.execution.datasources.{ModifiedAfterFilter, ModifiedBeforeFilter} +import org.apache.spark.sql.execution.datasources.FileIndexOptions import org.apache.spark.util.Utils /** @@ -36,7 +36,7 @@ class FileStreamOptions(parameters: CaseInsensitiveMap[String]) extends Logging checkDisallowedOptions() private def checkDisallowedOptions(): Unit = { - Seq(ModifiedBeforeFilter.PARAM_NAME, ModifiedAfterFilter.PARAM_NAME).foreach { param => + Seq(FileIndexOptions.MODIFIED_BEFORE, FileIndexOptions.MODIFIED_AFTER).foreach { param => if (parameters.contains(param)) { throw new IllegalArgumentException(s"option '$param' is not allowed in file stream sources") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 1897a347ef175..07018508b91cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -539,6 +539,18 @@ class FileIndexSuite extends SharedSparkSession { } } } + + test("SPARK-40667: validate FileIndex Options") { + assert(FileIndexOptions.getAllOptions.size == 7) + // Please add validation on any new FileIndex options here + assert(FileIndexOptions.isValidOption("ignoreMissingFiles")) + assert(FileIndexOptions.isValidOption("timeZone")) + assert(FileIndexOptions.isValidOption("recursiveFileLookup")) + assert(FileIndexOptions.isValidOption("basePath")) + assert(FileIndexOptions.isValidOption("modifiedbefore")) + assert(FileIndexOptions.isValidOption("modifiedafter")) + assert(FileIndexOptions.isValidOption("pathglobfilter")) + } } object DeletionRaceFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index b644f6262304a..0c561fd8e7b4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -38,6 +38,7 @@ import org.apache.logging.log4j.Level import org.apache.spark.{SparkConf, SparkException, SparkUpgradeException, TestUtils} import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Encoders, QueryTest, Row} +import 
org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} import org.apache.spark.sql.execution.datasources.CommonFileDataSourceSuite import org.apache.spark.sql.internal.SQLConf @@ -3080,6 +3081,57 @@ abstract class CSVSuite } } } + + test("SPARK-40667: validate CSV Options") { + assert(CSVOptions.getAllOptions.size == 38) + // Please add validation on any new CSV options here + assert(CSVOptions.isValidOption("header")) + assert(CSVOptions.isValidOption("inferSchema")) + assert(CSVOptions.isValidOption("ignoreLeadingWhiteSpace")) + assert(CSVOptions.isValidOption("ignoreTrailingWhiteSpace")) + assert(CSVOptions.isValidOption("prefersDate")) + assert(CSVOptions.isValidOption("escapeQuotes")) + assert(CSVOptions.isValidOption("quoteAll")) + assert(CSVOptions.isValidOption("enforceSchema")) + assert(CSVOptions.isValidOption("quote")) + assert(CSVOptions.isValidOption("escape")) + assert(CSVOptions.isValidOption("comment")) + assert(CSVOptions.isValidOption("maxColumns")) + assert(CSVOptions.isValidOption("maxCharsPerColumn")) + assert(CSVOptions.isValidOption("mode")) + assert(CSVOptions.isValidOption("charToEscapeQuoteEscaping")) + assert(CSVOptions.isValidOption("locale")) + assert(CSVOptions.isValidOption("dateFormat")) + assert(CSVOptions.isValidOption("timestampFormat")) + assert(CSVOptions.isValidOption("timestampNTZFormat")) + assert(CSVOptions.isValidOption("enableDateTimeParsingFallback")) + assert(CSVOptions.isValidOption("multiLine")) + assert(CSVOptions.isValidOption("samplingRatio")) + assert(CSVOptions.isValidOption("emptyValue")) + assert(CSVOptions.isValidOption("lineSep")) + assert(CSVOptions.isValidOption("inputBufferSize")) + assert(CSVOptions.isValidOption("columnNameOfCorruptRecord")) + assert(CSVOptions.isValidOption("nullValue")) + assert(CSVOptions.isValidOption("nanValue")) + assert(CSVOptions.isValidOption("positiveInf")) + assert(CSVOptions.isValidOption("negativeInf")) + assert(CSVOptions.isValidOption("timeZone")) + assert(CSVOptions.isValidOption("unescapedQuoteHandling")) + assert(CSVOptions.isValidOption("encoding")) + assert(CSVOptions.isValidOption("charset")) + assert(CSVOptions.isValidOption("compression")) + assert(CSVOptions.isValidOption("codec")) + assert(CSVOptions.isValidOption("sep")) + assert(CSVOptions.isValidOption("delimiter")) + // Please add validation on any new CSV options with alternatives here + assert(CSVOptions.getAlternativeOption("sep").contains("delimiter")) + assert(CSVOptions.getAlternativeOption("delimiter").contains("sep")) + assert(CSVOptions.getAlternativeOption("encoding").contains("charset")) + assert(CSVOptions.getAlternativeOption("charset").contains("encoding")) + assert(CSVOptions.getAlternativeOption("compression").contains("codec")) + assert(CSVOptions.getAlternativeOption("codec").contains("compression")) + assert(CSVOptions.getAlternativeOption("prefersDate").isEmpty) + } } class CSVv1Suite extends CSVSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 212d590813720..d50fce0f6a9a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3381,6 +3381,43 @@ abstract class JsonSuite } } } + + test("SPARK-40667: validate JSON Options") { +
assert(JSONOptions.getAllOptions.size == 28) + // Please add validation on any new Json options here + assert(JSONOptions.isValidOption("samplingRatio")) + assert(JSONOptions.isValidOption("primitivesAsString")) + assert(JSONOptions.isValidOption("prefersDecimal")) + assert(JSONOptions.isValidOption("allowComments")) + assert(JSONOptions.isValidOption("allowUnquotedFieldNames")) + assert(JSONOptions.isValidOption("allowSingleQuotes")) + assert(JSONOptions.isValidOption("allowNumericLeadingZeros")) + assert(JSONOptions.isValidOption("allowNonNumericNumbers")) + assert(JSONOptions.isValidOption("allowBackslashEscapingAnyCharacter")) + assert(JSONOptions.isValidOption("allowUnquotedControlChars")) + assert(JSONOptions.isValidOption("compression")) + assert(JSONOptions.isValidOption("mode")) + assert(JSONOptions.isValidOption("dropFieldIfAllNull")) + assert(JSONOptions.isValidOption("ignoreNullFields")) + assert(JSONOptions.isValidOption("locale")) + assert(JSONOptions.isValidOption("dateFormat")) + assert(JSONOptions.isValidOption("timestampFormat")) + assert(JSONOptions.isValidOption("timestampNTZFormat")) + assert(JSONOptions.isValidOption("enableDateTimeParsingFallback")) + assert(JSONOptions.isValidOption("multiLine")) + assert(JSONOptions.isValidOption("lineSep")) + assert(JSONOptions.isValidOption("pretty")) + assert(JSONOptions.isValidOption("inferTimestamp")) + assert(JSONOptions.isValidOption("columnNameOfCorruptRecord")) + assert(JSONOptions.isValidOption("timeZone")) + assert(JSONOptions.isValidOption("writeNonAsciiCharacterAsCodePoint")) + assert(JSONOptions.isValidOption("encoding")) + assert(JSONOptions.isValidOption("charset")) + // Please add validation on any new Json options with alternative here + assert(JSONOptions.getAlternativeOption("encoding").contains("charset")) + assert(JSONOptions.getAlternativeOption("charset").contains("encoding")) + assert(JSONOptions.getAlternativeOption("dateFormat").isEmpty) + } } class JsonV1Suite extends JsonSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 19477adec3960..94ce3d77962ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -580,6 +580,14 @@ abstract class OrcSuite "ORC sources shall write an empty file contains meta if necessary") } } + + test("SPARK-40667: validate Orc Options") { + assert(OrcOptions.getAllOptions.size == 3) + // Please add validation on any new Orc options here + assert(OrcOptions.isValidOption("mergeSchema")) + assert(OrcOptions.isValidOption("orc.compress")) + assert(OrcOptions.isValidOption("compression")) + } } abstract class OrcSourceSuite extends OrcSuite with SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 8b8c4918e5bb0..fea986cc8e2de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -1485,6 +1485,16 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } } + test("SPARK-40667: validate Parquet Options") { + 
assert(ParquetOptions.getAllOptions.size == 5) + // Please add validation on any new parquet options here + assert(ParquetOptions.isValidOption("mergeSchema")) + assert(ParquetOptions.isValidOption("compression")) + assert(ParquetOptions.isValidOption("parquet.compression")) + assert(ParquetOptions.isValidOption("datetimeRebaseMode")) + assert(ParquetOptions.isValidOption("int96RebaseMode")) + } + test("SPARK-23173 Writing a file with data converted from JSON with and incorrect user schema") { withTempPath { file => val jsonData = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 1eb32ed285799..ff6b9aadf7cfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -236,6 +236,15 @@ abstract class TextSuite extends QueryTest with SharedSparkSession with CommonFi assert(data(3) == Row("\"doh\"")) assert(data.length == 4) } + + test("SPARK-40667: validate Text Options") { + assert(TextOptions.getAllOptions.size == 4) + // Please add validation on any new Text options here + assert(TextOptions.isValidOption("compression")) + assert(TextOptions.isValidOption("wholetext")) + assert(TextOptions.isValidOption("encoding")) + assert(TextOptions.isValidOption("lineSep")) + } } class TextV1Suite extends TextSuite {
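A short sketch of the pattern this patch establishes, for illustration only and not part of the diff. The object name `MyFormatOptions` and the `rejectUnknown` helper are hypothetical; only `newOption`, `isValidOption`, `getAllOptions`, and `getAlternativeOption` come from the new DataSourceOptions trait above:

```scala
import org.apache.spark.sql.catalyst.DataSourceOptions

// Hypothetical options object for a new data source, reusing the registry.
object MyFormatOptions extends DataSourceOptions {
  val MODE = newOption("mode")           // plain option
  val ENCODING = "encoding"
  val CHARSET = "charset"
  newOption(ENCODING, CHARSET)           // pair of alternative names

  // Mirrors the validation pattern used by the new SPARK-40667 test cases:
  // every user-supplied key must have been registered via newOption.
  def rejectUnknown(userOptions: Map[String, String]): Unit = {
    val unknown = userOptions.keySet.filterNot(isValidOption)
    require(unknown.isEmpty, s"Unknown options: ${unknown.mkString(", ")}")
  }
}
```

Because each options object now declares its keys through `newOption`, the per-format test suites can assert the registered set exhaustively (as the SPARK-40667 tests above do) instead of scattering raw string literals through the readers and writers.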