From fe904e6973b7a8fdadc5e253a6a74e8ccb359287 Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 4 Dec 2024 16:11:42 -0400
Subject: [PATCH] [SPARK-49709][CONNECT][SQL] Support ConfigEntry in the RuntimeConfig interface

### What changes were proposed in this pull request?
This PR adds support for ConfigEntry to the RuntimeConfig interface. This was removed in https://github.com/apache/spark/pull/47980.

### Why are the changes needed?
This functionality is used a lot by Spark libraries; a usage sketch follows the diff below. Removing it caused friction, and adding it back does not pollute the RuntimeConfig interface.

### Does this PR introduce _any_ user-facing change?
No. This is developer API.

### How was this patch tested?
I have added test cases for Connect and Classic.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49062 from hvanhovell/SPARK-49709.

Authored-by: Herman van Hovell
Signed-off-by: Herman van Hovell
---
 .../spark/internal/config/ConfigBuilder.scala | 4 +-
 .../spark/internal/config/ConfigEntry.scala | 0
 .../internal/config/ConfigProvider.scala | 17 -----
 .../spark/internal/config/ConfigReader.scala | 0
 .../apache/spark/util/SparkStringUtils.scala | 26 ++++++++
 .../sql/internal/ConnectRuntimeConfig.scala | 62 ++++++++++++++-----
 .../apache/spark/sql/ClientE2ETestSuite.scala | 23 +++++++
 .../internal/config/SparkConfigProvider.scala | 35 +++++++++++
 .../scala/org/apache/spark/util/Utils.scala | 7 +--
 .../org/apache/spark/sql/RuntimeConfig.scala | 25 ++++++++
 .../sql/internal/RuntimeConfigImpl.scala | 20 +++++-
 .../apache/spark/sql/RuntimeConfigSuite.scala | 22 +++++++
 12 files changed, 201 insertions(+), 40 deletions(-)
 rename {core => common/utils}/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala (99%)
 rename {core => common/utils}/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala (100%)
 rename {core => common/utils}/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala (78%)
 rename {core => common/utils}/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala (100%)
 create mode 100644 common/utils/src/main/scala/org/apache/spark/util/SparkStringUtils.scala
 create mode 100644 core/src/main/scala/org/apache/spark/internal/config/SparkConfigProvider.scala
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala rename to common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index f50cc0f88842a..d3e975d1782f0 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -24,7 +24,7 @@ import scala.util.matching.Regex import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.network.util.{ByteUnit, JavaUtils} -import org.apache.spark.util.Utils +import org.apache.spark.util.SparkStringUtils private object ConfigHelpers { @@ -47,7 +47,7 @@ private object ConfigHelpers { } def stringToSeq[T](str: String, converter: String => T): Seq[T] = { - Utils.stringToSeq(str).map(converter) + SparkStringUtils.stringToSeq(str).map(converter) } def seqToString[T](v: Seq[T], stringConverter: T => String): String = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala
b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala similarity index 100% rename from core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala rename to common/utils/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala similarity index 78% rename from core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala rename to common/utils/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala index 392f9d56e7f51..fef019ef1f560 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigProvider.scala @@ -19,8 +19,6 @@ package org.apache.spark.internal.config import java.util.{Map => JMap} -import org.apache.spark.SparkConf - /** * A source of configuration values. */ @@ -47,18 +45,3 @@ private[spark] class MapProvider(conf: JMap[String, String]) extends ConfigProvi override def get(key: String): Option[String] = Option(conf.get(key)) } - -/** - * A config provider that only reads Spark config keys. - */ -private[spark] class SparkConfigProvider(conf: JMap[String, String]) extends ConfigProvider { - - override def get(key: String): Option[String] = { - if (key.startsWith("spark.")) { - Option(conf.get(key)).orElse(SparkConf.getDeprecatedConfig(key, conf)) - } else { - None - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala similarity index 100% rename from core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala rename to common/utils/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkStringUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkStringUtils.scala new file mode 100644 index 0000000000000..6915f373b84e5 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkStringUtils.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.util + +trait SparkStringUtils { + def stringToSeq(str: String): Seq[String] = { + import org.apache.spark.util.ArrayImplicits._ + str.split(",").map(_.trim()).filter(_.nonEmpty).toImmutableArraySeq + } +} + +object SparkStringUtils extends SparkStringUtils diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala index be1a13cb2fed2..74348e8e015e2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.connect.proto.{ConfigRequest, ConfigResponse, KeyValue} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{ConfigEntry, ConfigReader, OptionalConfigEntry} import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.connect.client.SparkConnectClient @@ -28,7 +29,7 @@ import org.apache.spark.sql.connect.client.SparkConnectClient */ class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) extends RuntimeConfig - with Logging { + with Logging { self => /** @inheritdoc */ def set(key: String, value: String): Unit = { @@ -37,6 +38,13 @@ class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) } } + /** @inheritdoc */ + override private[sql] def set[T](entry: ConfigEntry[T], value: T): Unit = { + require(entry != null, "entry cannot be null") + require(value != null, s"value cannot be null for key: ${entry.key}") + set(entry.key, entry.stringConverter(value)) + } + /** @inheritdoc */ @throws[NoSuchElementException]("if the key is not set and there is no default value") def get(key: String): String = getOption(key).getOrElse { @@ -45,11 +53,39 @@ class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) /** @inheritdoc */ def get(key: String, default: String): String = { - executeConfigRequestSingleValue { builder => - builder.getGetWithDefaultBuilder.addPairsBuilder().setKey(key).setValue(default) + val kv = executeConfigRequestSinglePair { builder => + val pairsBuilder = builder.getGetWithDefaultBuilder + .addPairsBuilder() + .setKey(key) + if (default != null) { + pairsBuilder.setValue(default) + } + } + if (kv.hasValue) { + kv.getValue + } else { + default } } + /** @inheritdoc */ + override private[sql] def get[T](entry: ConfigEntry[T]): T = { + require(entry != null, "entry cannot be null") + entry.readFrom(reader) + } + + /** @inheritdoc */ + override private[sql] def get[T](entry: OptionalConfigEntry[T]): Option[T] = { + require(entry != null, "entry cannot be null") + entry.readFrom(reader) + } + + /** @inheritdoc */ + override private[sql] def get[T](entry: ConfigEntry[T], default: T): T = { + require(entry != null, "entry cannot be null") + Option(get(entry.key, null)).map(entry.valueConverter).getOrElse(default) + } + /** @inheritdoc */ def getAll: Map[String, String] = { val response = executeConfigRequest { builder => @@ -65,11 +101,11 @@ class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) /** @inheritdoc */ def getOption(key: String): Option[String] = { - val pair = executeConfigRequestSinglePair { builder => + val kv = executeConfigRequestSinglePair { builder => builder.getGetOptionBuilder.addKeys(key) } - if (pair.hasValue) { - Option(pair.getValue) + if 
(kv.hasValue) { + Option(kv.getValue) } else { None } @@ -84,17 +120,11 @@ class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) /** @inheritdoc */ def isModifiable(key: String): Boolean = { - val modifiable = executeConfigRequestSingleValue { builder => + val kv = executeConfigRequestSinglePair { builder => builder.getIsModifiableBuilder.addKeys(key) } - java.lang.Boolean.valueOf(modifiable) - } - - private def executeConfigRequestSingleValue( - f: ConfigRequest.Operation.Builder => Unit): String = { - val pair = executeConfigRequestSinglePair(f) - require(pair.hasValue, "The returned pair does not have a value set") - pair.getValue + require(kv.hasValue, "The returned pair does not have a value set") + java.lang.Boolean.valueOf(kv.getValue) } private def executeConfigRequestSinglePair( @@ -113,4 +143,6 @@ class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) } response } + + private val reader = new ConfigReader((key: String) => Option(self.get(key, null))) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 92b5808f4d626..c7979b8e033ea 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException} import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION} +import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, TableAlreadyExistsException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema @@ -1006,8 +1007,12 @@ class ClientE2ETestSuite test("RuntimeConfig") { intercept[NoSuchElementException](spark.conf.get("foo.bar")) assert(spark.conf.getOption("foo.bar").isEmpty) + assert(spark.conf.get("foo.bar", "nope") == "nope") + assert(spark.conf.get("foo.bar", null) == null) spark.conf.set("foo.bar", value = true) assert(spark.conf.getOption("foo.bar") === Option("true")) + assert(spark.conf.get("foo.bar", "nope") === "true") + assert(spark.conf.get("foo.bar", null) === "true") spark.conf.set("foo.bar.numBaz", 100L) assert(spark.conf.get("foo.bar.numBaz") === "100") spark.conf.set("foo.bar.name", "donkey") @@ -1020,6 +1025,24 @@ class ClientE2ETestSuite assert(spark.conf.isModifiable("spark.sql.ansi.enabled")) assert(!spark.conf.isModifiable("spark.sql.globalTempDatabase")) intercept[Exception](spark.conf.set("spark.sql.globalTempDatabase", "/dev/null")) + + val entry = ConfigBuilder("my.simple.conf").intConf.createOptional + intercept[NoSuchElementException](spark.conf.get(entry.key)) + assert(spark.conf.get(entry).isEmpty) + assert(spark.conf.get(entry, Option(55)) === Option(55)) + spark.conf.set(entry, Option(33)) + assert(spark.conf.get(entry.key) === "33") + assert(spark.conf.get(entry) === Option(33)) + assert(spark.conf.get(entry, Option(55)) === Option(33)) + + val entryWithDefault = ConfigBuilder("my.important.conf").intConf.createWithDefault(10) + intercept[NoSuchElementException](spark.conf.get(entryWithDefault.key)) + assert(spark.conf.get(entryWithDefault) === 10) + 
assert(spark.conf.get(entryWithDefault, 11) === 11) + spark.conf.set(entryWithDefault, 12) + assert(spark.conf.get(entryWithDefault.key) === "12") + assert(spark.conf.get(entryWithDefault) === 12) + assert(spark.conf.get(entryWithDefault, 11) === 12) } test("SparkVersion") { diff --git a/core/src/main/scala/org/apache/spark/internal/config/SparkConfigProvider.scala b/core/src/main/scala/org/apache/spark/internal/config/SparkConfigProvider.scala new file mode 100644 index 0000000000000..8739c87a65877 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/config/SparkConfigProvider.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.internal.config + +import java.util.{Map => JMap} + +import org.apache.spark.SparkConf + +/** + * A config provider that only reads Spark config keys. + */ +private[spark] class SparkConfigProvider(conf: JMap[String, String]) extends ConfigProvider { + + override def get(key: String): Option[String] = { + if (key.startsWith("spark.")) { + Option(conf.get(key)).orElse(SparkConf.getDeprecatedConfig(key, conf)) + } else { + None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b2cf99241fdee..9e7ba6d879aa0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -103,7 +103,8 @@ private[spark] object Utils with SparkErrorUtils with SparkFileUtils with SparkSerDeUtils - with SparkStreamUtils { + with SparkStreamUtils + with SparkStringUtils { private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler @volatile private var cachedLocalDir: String = "" @@ -2799,10 +2800,6 @@ private[spark] object Utils } } - def stringToSeq(str: String): Seq[String] = { - str.split(",").map(_.trim()).filter(_.nonEmpty).toImmutableArraySeq - } - /** * Create instances of extension classes. * diff --git a/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 9e6e0e97f0302..091fbf20a0a7f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.Stable +import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} /** * Runtime configuration interface for Spark. To access this, use `SparkSession.conf`. @@ -53,6 +54,11 @@ abstract class RuntimeConfig { set(key, value.toString) } + /** + * Sets the given Spark runtime configuration property. 
+ */ + private[sql] def set[T](entry: ConfigEntry[T], value: T): Unit + /** * Returns the value of Spark runtime configuration property for the given key. If the key is * not set yet, return its default value if possible, otherwise `NoSuchElementException` will be @@ -74,6 +80,25 @@ abstract class RuntimeConfig { */ def get(key: String, default: String): String + /** + * Returns the value of Spark runtime configuration property for the given key. If the key is + * not set yet, return `defaultValue` in [[ConfigEntry]]. + */ + @throws[NoSuchElementException]("if the key is not set") + private[sql] def get[T](entry: ConfigEntry[T]): T + + /** + * Returns the value of Spark runtime configuration property for the given key. If the key is + * not set yet, return None. + */ + private[sql] def get[T](entry: OptionalConfigEntry[T]): Option[T] + + /** + * Returns the value of Spark runtime configuration property for the given key. If the key is + * not set yet, return the user given `default`. + */ + private[sql] def get[T](entry: ConfigEntry[T], default: T): T + /** * Returns all properties set in this conf. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala index 1739b86c8dcb4..b2004215a99f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.SPARK_DOC_ROOT import org.apache.spark.annotation.Stable -import org.apache.spark.internal.config.{ConfigEntry, DEFAULT_PARALLELISM} +import org.apache.spark.internal.config.{ConfigEntry, DEFAULT_PARALLELISM, OptionalConfigEntry} import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.errors.QueryCompilationErrors @@ -41,6 +41,12 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends sqlConf.setConfString(key, value) } + /** @inheritdoc */ + override private[sql] def set[T](entry: ConfigEntry[T], value: T): Unit = { + requireNonStaticConf(entry.key) + sqlConf.setConf(entry, value) + } + /** @inheritdoc */ @throws[NoSuchElementException]("if the key is not set and there is no default value") def get(key: String): String = { @@ -57,6 +63,18 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends sqlConf.getAllConfs } + /** @inheritdoc */ + override private[sql] def get[T](entry: ConfigEntry[T]): T = + sqlConf.getConf(entry) + + /** @inheritdoc */ + override private[sql] def get[T](entry: OptionalConfigEntry[T]): Option[T] = + sqlConf.getConf(entry) + + /** @inheritdoc */ + override private[sql] def get[T](entry: ConfigEntry[T], default: T): T = + sqlConf.getConf(entry, default) + private[sql] def getAllAsJava: java.util.Map[String, String] = { getAll.asJava } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index c80787c40c487..ce3ac9b8834bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -108,4 +108,26 @@ class RuntimeConfigSuite extends SparkFunSuite { // this set should not fail conf.set(DEFAULT_PARALLELISM.key, "1") } + + test("config entry") { + val conf = newConf() + + val entry = SQLConf.FILES_MAX_PARTITION_NUM + assert(conf.get(entry.key) === 
null) + assert(conf.get(entry).isEmpty) + assert(conf.get(entry, Option(55)) === Option(55)) + conf.set(entry, Option(33)) + assert(conf.get(entry.key) === "33") + assert(conf.get(entry) === Option(33)) + assert(conf.get(entry, Option(55)) === Option(33)) + + val entryWithDefault = SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD + assert(conf.get(entryWithDefault.key) === "10") + assert(conf.get(entryWithDefault) === 10) + assert(conf.get(entryWithDefault, 11) === 11) + conf.set(entryWithDefault, 12) + assert(conf.get(entryWithDefault.key) === "12") + assert(conf.get(entryWithDefault) === 12) + assert(conf.get(entryWithDefault, 11) === 12) + } }
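
Usage sketch (reviewer note, not part of the patch): a minimal, self-contained illustration of how a downstream Spark library might use the restored overloads. The package, object, and config key below are hypothetical; because the overloads are private[sql] (and ConfigBuilder is private[spark]), the calling code has to live under the org.apache.spark.sql package. Per the tests above, the same code works against both Classic and Connect sessions.

package org.apache.spark.sql.mylib // hypothetical package, chosen so private[sql] members are visible

import org.apache.spark.internal.config.ConfigBuilder
import org.apache.spark.sql.SparkSession

object MyLibConf {
  // A typed entry with a default; using .createOptional instead would make
  // reads return Option[Int].
  val MAX_WIDGETS = ConfigBuilder("spark.mylib.maxWidgets")
    .doc("Maximum number of widgets to track.")
    .intConf
    .createWithDefault(10)
}

object MyLib {
  // Typed read: yields the entry's default (10) until the key is set.
  def maxWidgets(spark: SparkSession): Int =
    spark.conf.get(MyLibConf.MAX_WIDGETS)

  // Typed write: the value is rendered with the entry's stringConverter, so a
  // subsequent spark.conf.get("spark.mylib.maxWidgets") returns its string form.
  def configure(spark: SparkSession, n: Int): Unit =
    spark.conf.set(MyLibConf.MAX_WIDGETS, n)
}

Compared with raw string keys, the entry keeps the type conversion and the default value in one place, which is why removing these overloads in #47980 caused friction for libraries.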