diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 33ed0d5493e0e..f6af9ccc41b39 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -248,7 +248,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
    * - This will throw an exception is the config is not optional and the value is not set.
    */
   private[spark] def get[T](entry: ConfigEntry[T]): T = {
-    entry.readFrom(this)
+    entry.readFrom(settings, getenv)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
index 5d50e3851a9f0..0f5c8a9e02ab8 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -116,11 +116,17 @@ private[spark] class TypedConfigBuilder[T](
 
   /** Creates a [[ConfigEntry]] that has a default value. */
   def createWithDefault(default: T): ConfigEntry[T] = {
-    val transformedDefault = converter(stringConverter(default))
-    val entry = new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter,
-      stringConverter, parent._doc, parent._public)
-    parent._onCreate.foreach(_(entry))
-    entry
+    // Treat "String" as a special case, so that both createWithDefault and createWithDefaultString
+    // behave the same w.r.t. variable expansion of default values.
+    if (default.isInstanceOf[String]) {
+      createWithDefaultString(default.asInstanceOf[String])
+    } else {
+      val transformedDefault = converter(stringConverter(default))
+      val entry = new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter,
+        stringConverter, parent._doc, parent._public)
+      parent._onCreate.foreach(_(entry))
+      entry
+    }
   }
 
   /**
@@ -128,8 +134,7 @@ private[spark] class TypedConfigBuilder[T](
    * [[String]] and must be a valid value for the entry.
    */
   def createWithDefaultString(default: String): ConfigEntry[T] = {
-    val typedDefault = converter(default)
-    val entry = new ConfigEntryWithDefault[T](parent.key, typedDefault, converter, stringConverter,
+    val entry = new ConfigEntryWithDefaultString[T](parent.key, default, converter, stringConverter,
       parent._doc, parent._public)
     parent._onCreate.foreach(_(entry))
     entry
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
index f7296b487c0e9..e2e23b3c3c32f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
@@ -17,11 +17,35 @@
 
 package org.apache.spark.internal.config
 
+import java.util.{Map => JMap}
+
+import scala.util.matching.Regex
+
 import org.apache.spark.SparkConf
 
 /**
  * An entry contains all meta information for a configuration.
  *
+ * Config options created using this feature support variable expansion. If the config value
+ * contains variable references of the form "${prefix:variableName}", the reference will be replaced
+ * with the value of the variable depending on the prefix. The prefix can be one of:
+ *
+ * - no prefix: if the config key starts with "spark", looks for the value in the Spark config
+ * - system: looks for the value in the system properties
+ * - env: looks for the value in the environment
+ *
+ * So referencing "${spark.master}" will look for the value of "spark.master" in the Spark
+ * configuration, while referencing "${env:MASTER}" will read the value from the "MASTER"
+ * environment variable.
+ *
+ * For known Spark configuration keys (i.e. those created using `ConfigBuilder`), references
+ * will also consider the default value when it exists.
+ *
+ * If the reference cannot be resolved, the original string will be retained.
+ *
+ * Variable expansion is also applied to the default values of config entries that have a default
+ * value declared as a string.
+ *
  * @param key the key for the configuration
  * @param defaultValue the default value for the configuration
  * @param valueConverter how to convert a string to the value. It should throw an exception if the
@@ -42,17 +66,27 @@ private[spark] abstract class ConfigEntry[T] (
     val doc: String,
     val isPublic: Boolean) {
 
+  import ConfigEntry._
+
+  registerEntry(this)
+
   def defaultValueString: String
 
-  def readFrom(conf: SparkConf): T
+  def readFrom(conf: JMap[String, String], getenv: String => String): T
 
-  // This is used by SQLConf, since it doesn't use SparkConf to store settings and thus cannot
-  // use readFrom().
   def defaultValue: Option[T] = None
 
   override def toString: String = {
     s"ConfigEntry(key=$key, defaultValue=$defaultValueString, doc=$doc, public=$isPublic)"
   }
+
+  protected def readAndExpand(
+      conf: JMap[String, String],
+      getenv: String => String,
+      usedRefs: Set[String] = Set()): Option[String] = {
+    Option(conf.get(key)).map(expand(_, conf, getenv, usedRefs))
+  }
+
 }
 
 private class ConfigEntryWithDefault[T] (
@@ -68,12 +102,36 @@ private class ConfigEntryWithDefault[T] (
 
   override def defaultValueString: String = stringConverter(_defaultValue)
 
-  override def readFrom(conf: SparkConf): T = {
-    conf.getOption(key).map(valueConverter).getOrElse(_defaultValue)
+  def readFrom(conf: JMap[String, String], getenv: String => String): T = {
+    readAndExpand(conf, getenv).map(valueConverter).getOrElse(_defaultValue)
   }
 
 }
 
+private class ConfigEntryWithDefaultString[T] (
+    key: String,
+    _defaultValue: String,
+    valueConverter: String => T,
+    stringConverter: T => String,
+    doc: String,
+    isPublic: Boolean)
+    extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) {
+
+  override def defaultValue: Option[T] = Some(valueConverter(_defaultValue))
+
+  override def defaultValueString: String = _defaultValue
+
+  def readFrom(conf: JMap[String, String], getenv: String => String): T = {
+    Option(conf.get(key))
+      .orElse(Some(_defaultValue))
+      .map(ConfigEntry.expand(_, conf, getenv, Set()))
+      .map(valueConverter)
+      .get
+  }
+
+}
+
+
 /**
  * A config entry that does not have a default value.
  */
@@ -88,7 +146,9 @@ private[spark] class OptionalConfigEntry[T](
 
   override def defaultValueString: String = ""
 
-  override def readFrom(conf: SparkConf): Option[T] = conf.getOption(key).map(rawValueConverter)
+  override def readFrom(conf: JMap[String, String], getenv: String => String): Option[T] = {
+    readAndExpand(conf, getenv).map(rawValueConverter)
+  }
 
 }
 
@@ -99,13 +159,67 @@ private class FallbackConfigEntry[T] (
     key: String,
     doc: String,
     isPublic: Boolean,
-    private val fallback: ConfigEntry[T])
+    private[config] val fallback: ConfigEntry[T])
     extends ConfigEntry[T](key, fallback.valueConverter, fallback.stringConverter, doc, isPublic) {
 
   override def defaultValueString: String = s"<value of ${fallback.key}>"
 
-  override def readFrom(conf: SparkConf): T = {
-    conf.getOption(key).map(valueConverter).getOrElse(fallback.readFrom(conf))
+  override def readFrom(conf: JMap[String, String], getenv: String => String): T = {
+    Option(conf.get(key)).map(valueConverter).getOrElse(fallback.readFrom(conf, getenv))
+  }
+
+}
+
+private object ConfigEntry {
+
+  private val knownConfigs = new java.util.concurrent.ConcurrentHashMap[String, ConfigEntry[_]]()
+
+  private val REF_RE = "\\$\\{(?:(\\w+?):)?(\\S+?)\\}".r
+
+  def registerEntry(entry: ConfigEntry[_]): Unit = {
+    val existing = knownConfigs.putIfAbsent(entry.key, entry)
+    require(existing == null, s"Config entry ${entry.key} already registered!")
+  }
+
+  def findEntry(key: String): ConfigEntry[_] = knownConfigs.get(key)
+
+  /**
+   * Expand the `value` according to the rules explained in ConfigEntry.
+   */
+  def expand(
+      value: String,
+      conf: JMap[String, String],
+      getenv: String => String,
+      usedRefs: Set[String]): String = {
+    REF_RE.replaceAllIn(value, { m =>
+      val prefix = m.group(1)
+      val name = m.group(2)
+      val replacement = prefix match {
+        case null =>
+          require(!usedRefs.contains(name), s"Circular reference in $value: $name")
+          if (name.startsWith("spark.")) {
+            Option(findEntry(name))
+              .flatMap(_.readAndExpand(conf, getenv, usedRefs = usedRefs + name))
+              .orElse(Option(conf.get(name)))
+              .orElse(defaultValueString(name))
+          } else {
+            None
+          }
+        case "system" => sys.props.get(name)
+        case "env" => Option(getenv(name))
+        case _ => None
+      }
+      Regex.quoteReplacement(replacement.getOrElse(m.matched))
+    })
+  }
+
+  private def defaultValueString(key: String): Option[String] = {
+    findEntry(key) match {
+      case e: ConfigEntryWithDefault[_] => Some(e.defaultValueString)
+      case e: ConfigEntryWithDefaultString[_] => Some(e.defaultValueString)
+      case e: FallbackConfigEntry[_] => defaultValueString(e.fallback.key)
+      case _ => None
+    }
   }
 
 }
diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
index 337fd7e85e81c..ebdb69f31e360 100644
--- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
+++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
@@ -19,14 +19,21 @@ package org.apache.spark.internal.config
 
 import java.util.concurrent.TimeUnit
 
+import scala.collection.JavaConverters._
+import scala.collection.mutable.HashMap
+
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.network.util.ByteUnit
 
 class ConfigEntrySuite extends SparkFunSuite {
 
+  private val PREFIX = "spark.ConfigEntrySuite"
+
+  private def testKey(name: String): String = s"$PREFIX.$name"
+
   test("conf entry: int") {
     val conf = new SparkConf()
-    val iConf = ConfigBuilder("spark.int").intConf.createWithDefault(1)
+    val iConf = ConfigBuilder(testKey("int")).intConf.createWithDefault(1)
     assert(conf.get(iConf) === 1)
     conf.set(iConf, 2)
     assert(conf.get(iConf) === 2)
@@ -34,21 +41,21 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: long") {
     val conf = new SparkConf()
-    val lConf = ConfigBuilder("spark.long").longConf.createWithDefault(0L)
+    val lConf = ConfigBuilder(testKey("long")).longConf.createWithDefault(0L)
     conf.set(lConf, 1234L)
     assert(conf.get(lConf) === 1234L)
   }
 
   test("conf entry: double") {
     val conf = new SparkConf()
-    val dConf = ConfigBuilder("spark.double").doubleConf.createWithDefault(0.0)
+    val dConf = ConfigBuilder(testKey("double")).doubleConf.createWithDefault(0.0)
     conf.set(dConf, 20.0)
     assert(conf.get(dConf) === 20.0)
   }
 
   test("conf entry: boolean") {
     val conf = new SparkConf()
-    val bConf = ConfigBuilder("spark.boolean").booleanConf.createWithDefault(false)
+    val bConf = ConfigBuilder(testKey("boolean")).booleanConf.createWithDefault(false)
     assert(!conf.get(bConf))
     conf.set(bConf, true)
     assert(conf.get(bConf))
@@ -56,7 +63,7 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: optional") {
     val conf = new SparkConf()
-    val optionalConf = ConfigBuilder("spark.optional").intConf.createOptional
+    val optionalConf = ConfigBuilder(testKey("optional")).intConf.createOptional
     assert(conf.get(optionalConf) === None)
     conf.set(optionalConf, 1)
     assert(conf.get(optionalConf) === Some(1))
@@ -64,8 +71,8 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: fallback") {
     val conf = new SparkConf()
-    val parentConf = ConfigBuilder("spark.int").intConf.createWithDefault(1)
-    val confWithFallback = ConfigBuilder("spark.fallback").fallbackConf(parentConf)
+    val parentConf = ConfigBuilder(testKey("parent")).intConf.createWithDefault(1)
+    val confWithFallback = ConfigBuilder(testKey("fallback")).fallbackConf(parentConf)
     assert(conf.get(confWithFallback) === 1)
     conf.set(confWithFallback, 2)
     assert(conf.get(parentConf) === 1)
@@ -74,7 +81,8 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: time") {
     val conf = new SparkConf()
-    val time = ConfigBuilder("spark.time").timeConf(TimeUnit.SECONDS).createWithDefaultString("1h")
+    val time = ConfigBuilder(testKey("time")).timeConf(TimeUnit.SECONDS)
+      .createWithDefaultString("1h")
     assert(conf.get(time) === 3600L)
     conf.set(time.key, "1m")
     assert(conf.get(time) === 60L)
@@ -82,7 +90,8 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: bytes") {
     val conf = new SparkConf()
-    val bytes = ConfigBuilder("spark.bytes").bytesConf(ByteUnit.KiB).createWithDefaultString("1m")
+    val bytes = ConfigBuilder(testKey("bytes")).bytesConf(ByteUnit.KiB)
+      .createWithDefaultString("1m")
     assert(conf.get(bytes) === 1024L)
     conf.set(bytes.key, "1k")
     assert(conf.get(bytes) === 1L)
@@ -90,7 +99,7 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: string seq") {
     val conf = new SparkConf()
-    val seq = ConfigBuilder("spark.seq").stringConf.toSequence.createWithDefault(Seq())
+    val seq = ConfigBuilder(testKey("seq")).stringConf.toSequence.createWithDefault(Seq())
     conf.set(seq.key, "1,,2, 3 , , 4")
     assert(conf.get(seq) === Seq("1", "2", "3", "4"))
     conf.set(seq, Seq("1", "2"))
@@ -99,7 +108,7 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: int seq") {
     val conf = new SparkConf()
-    val seq = ConfigBuilder("spark.seq").intConf.toSequence.createWithDefault(Seq())
+    val seq = ConfigBuilder(testKey("intSeq")).intConf.toSequence.createWithDefault(Seq())
     conf.set(seq.key, "1,,2, 3 , , 4")
     assert(conf.get(seq) === Seq(1, 2, 3, 4))
     conf.set(seq, Seq(1, 2))
@@ -108,7 +117,7 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: transformation") {
     val conf = new SparkConf()
-    val transformationConf = ConfigBuilder("spark.transformation")
+    val transformationConf = ConfigBuilder(testKey("transformation"))
       .stringConf
       .transform(_.toLowerCase())
       .createWithDefault("FOO")
@@ -120,7 +129,7 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: valid values check") {
     val conf = new SparkConf()
-    val enum = ConfigBuilder("spark.enum")
+    val enum = ConfigBuilder(testKey("enum"))
       .stringConf
       .checkValues(Set("a", "b", "c"))
       .createWithDefault("a")
@@ -138,7 +147,7 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("conf entry: conversion error") {
     val conf = new SparkConf()
-    val conversionTest = ConfigBuilder("spark.conversionTest").doubleConf.createOptional
+    val conversionTest = ConfigBuilder(testKey("conversionTest")).doubleConf.createOptional
     conf.set(conversionTest.key, "abc")
     val conversionError = intercept[IllegalArgumentException] {
       conf.get(conversionTest)
@@ -148,8 +157,81 @@ class ConfigEntrySuite extends SparkFunSuite {
 
   test("default value handling is null-safe") {
     val conf = new SparkConf()
-    val stringConf = ConfigBuilder("spark.string").stringConf.createWithDefault(null)
+    val stringConf = ConfigBuilder(testKey("string")).stringConf.createWithDefault(null)
     assert(conf.get(stringConf) === null)
   }
 
+  test("variable expansion") {
+    val env = Map("ENV1" -> "env1")
+    val conf = HashMap("spark.value1" -> "value1", "spark.value2" -> "value2")
+
+    def getenv(key: String): String = env.getOrElse(key, null)
+
+    def expand(value: String): String = ConfigEntry.expand(value, conf.asJava, getenv, Set())
+
+    assert(expand("${spark.value1}") === "value1")
+    assert(expand("spark.value1 is: ${spark.value1}") === "spark.value1 is: value1")
+    assert(expand("${spark.value1} ${spark.value2}") === "value1 value2")
+    assert(expand("${spark.value3}") === "${spark.value3}")
+
+    // Make sure anything that is not in the "spark." namespace is ignored.
+    conf("notspark.key") = "value"
+    assert(expand("${notspark.key}") === "${notspark.key}")
+
+    assert(expand("${env:ENV1}") === "env1")
+    assert(expand("${system:user.name}") === sys.props("user.name"))
+
+    val stringConf = ConfigBuilder(testKey("stringForExpansion"))
+      .stringConf
+      .createWithDefault("string1")
+    val optionalConf = ConfigBuilder(testKey("optionForExpansion"))
+      .stringConf
+      .createOptional
+    val intConf = ConfigBuilder(testKey("intForExpansion"))
+      .intConf
+      .createWithDefault(42)
+    val fallbackConf = ConfigBuilder(testKey("fallbackForExpansion"))
+      .fallbackConf(intConf)
+
+    assert(expand("${" + stringConf.key + "}") === "string1")
+    assert(expand("${" + optionalConf.key + "}") === "${" + optionalConf.key + "}")
+    assert(expand("${" + intConf.key + "}") === "42")
+    assert(expand("${" + fallbackConf.key + "}") === "42")
+
+    conf(optionalConf.key) = "string2"
+    assert(expand("${" + optionalConf.key + "}") === "string2")
+
+    conf(fallbackConf.key) = "84"
+    assert(expand("${" + fallbackConf.key + "}") === "84")
+
+    assert(expand("${spark.value1") === "${spark.value1")
+
+    // Unknown prefixes.
+    assert(expand("${unknown:value}") === "${unknown:value}")
+
+    // Chained references.
+    val conf1 = ConfigBuilder(testKey("conf1"))
+      .stringConf
+      .createWithDefault("value1")
+    val conf2 = ConfigBuilder(testKey("conf2"))
+      .stringConf
+      .createWithDefault("value2")
+
+    conf(conf2.key) = "${" + conf1.key + "}"
+    assert(expand("${" + conf2.key + "}") === conf1.defaultValueString)
+
+    // Circular references.
+    conf(conf1.key) = "${" + conf2.key + "}"
+    val e = intercept[IllegalArgumentException] {
+      expand("${" + conf2.key + "}")
+    }
+    assert(e.getMessage().contains("Circular"))
+
+    // Default string values with variable references.
+    val parameterizedStringConf = ConfigBuilder(testKey("stringWithParams"))
+      .stringConf
+      .createWithDefault("${spark.value1}")
+    assert(parameterizedStringConf.readFrom(conf.asJava, getenv) === conf("spark.value1"))
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 14a1680fafa3a..12a11ad746218 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -677,9 +677,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH)
 
-  def warehousePath: String = {
-    getConf(WAREHOUSE_PATH).replace("${system:user.dir}", System.getProperty("user.dir"))
-  }
+  def warehousePath: String = getConf(WAREHOUSE_PATH)
 
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
@@ -738,8 +736,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
    */
   def getConf[T](entry: ConfigEntry[T]): T = {
     require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
-    Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue).
-      getOrElse(throw new NoSuchElementException(entry.key))
+    entry.readFrom(settings, System.getenv)
   }
 
   /**
@@ -748,7 +745,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
    */
   def getConf[T](entry: OptionalConfigEntry[T]): Option[T] = {
     require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
-    Option(settings.get(entry.key)).map(entry.rawValueConverter)
+    entry.readFrom(settings, System.getenv)
   }
 
   /**