diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 5f9a4411ecdcc..1ffbb8aa8814f 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -183,6 +183,97 @@ message ExecutePlanResponse {
   }
 }
 
+// The key-value pair for the config request and response.
+message KeyValue {
+  // (Required) The key.
+  string key = 1;
+  // (Optional) The value.
+  optional string value = 2;
+}
+
+// Request to update or fetch the configurations.
+message ConfigRequest {
+  // (Required)
+  //
+  // The client_id is set by the client to be able to collate streaming responses from
+  // different queries.
+  string client_id = 1;
+
+  // (Required) User context
+  UserContext user_context = 2;
+
+  // (Required) The operation for the config.
+  Operation operation = 3;
+
+  // Provides optional information about the client sending the request. This field
+  // can be used for language or version specific information and is only intended for
+  // logging purposes and will not be interpreted by the server.
+  optional string client_type = 4;
+
+  message Operation {
+    oneof op_type {
+      Set set = 1;
+      Get get = 2;
+      GetWithDefault get_with_default = 3;
+      GetOption get_option = 4;
+      GetAll get_all = 5;
+      Unset unset = 6;
+      IsModifiable is_modifiable = 7;
+    }
+  }
+
+  message Set {
+    // (Required) The config key-value pairs to set.
+    repeated KeyValue pairs = 1;
+  }
+
+  message Get {
+    // (Required) The config keys to get.
+    repeated string keys = 1;
+  }
+
+  message GetWithDefault {
+    // (Required) The config key-value pairs to get. The value will be used as the default value.
+    repeated KeyValue pairs = 1;
+  }
+
+  message GetOption {
+    // (Required) The config keys to get optionally.
+    repeated string keys = 1;
+  }
+
+  message GetAll {
+    // (Optional) The prefix of the config key to get.
+    optional string prefix = 1;
+  }
+
+  message Unset {
+    // (Required) The config keys to unset.
+    repeated string keys = 1;
+  }
+
+  message IsModifiable {
+    // (Required) The config keys to check whether they are modifiable.
+    repeated string keys = 1;
+  }
+}
+
+// Response to the config request.
+message ConfigResponse {
+  string client_id = 1;
+
+  // (Optional) The result key-value pairs.
+  //
+  // Available when the operation is 'Get', 'GetWithDefault', 'GetOption' or 'GetAll'.
+  // Also available for the operation 'IsModifiable' with the boolean strings "true" and "false".
+  repeated KeyValue pairs = 2;
+
+  // (Optional)
+  //
+  // Warning messages for deprecated or unsupported configurations.
+  repeated string warnings = 3;
+}
+
 // Main interface for the SparkConnect service.
 service SparkConnectService {
 
@@ -193,5 +284,8 @@ service SparkConnectService {
 
   // Analyzes a query and returns a [[AnalyzeResponse]] containing metadata about the query.
   rpc AnalyzePlan(AnalyzePlanRequest) returns (AnalyzePlanResponse) {}
+
+  // Updates or fetches the configurations and returns a [[ConfigResponse]] containing the result.
+ rpc Config(ConfigRequest) returns (ConfigResponse) {} } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala new file mode 100644 index 0000000000000..84f625222a856 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectConfigHandler.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import scala.collection.JavaConverters._ + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging +import org.apache.spark.sql.RuntimeConfig +import org.apache.spark.sql.internal.SQLConf + +class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigResponse]) + extends Logging { + + def handle(request: proto.ConfigRequest): Unit = { + val session = + SparkConnectService + .getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getClientId) + .session + + val builder = request.getOperation.getOpTypeCase match { + case proto.ConfigRequest.Operation.OpTypeCase.SET => + handleSet(request.getOperation.getSet, session.conf) + case proto.ConfigRequest.Operation.OpTypeCase.GET => + handleGet(request.getOperation.getGet, session.conf) + case proto.ConfigRequest.Operation.OpTypeCase.GET_WITH_DEFAULT => + handleGetWithDefault(request.getOperation.getGetWithDefault, session.conf) + case proto.ConfigRequest.Operation.OpTypeCase.GET_OPTION => + handleGetOption(request.getOperation.getGetOption, session.conf) + case proto.ConfigRequest.Operation.OpTypeCase.GET_ALL => + handleGetAll(request.getOperation.getGetAll, session.conf) + case proto.ConfigRequest.Operation.OpTypeCase.UNSET => + handleUnset(request.getOperation.getUnset, session.conf) + case proto.ConfigRequest.Operation.OpTypeCase.IS_MODIFIABLE => + handleIsModifiable(request.getOperation.getIsModifiable, session.conf) + case _ => throw new UnsupportedOperationException(s"${request.getOperation} not supported.") + } + + builder.setClientId(request.getClientId) + responseObserver.onNext(builder.build()) + responseObserver.onCompleted() + } + + private def handleSet( + operation: proto.ConfigRequest.Set, + conf: RuntimeConfig): proto.ConfigResponse.Builder = { + val builder = proto.ConfigResponse.newBuilder() + operation.getPairsList.asScala.iterator.foreach { pair => + val (key, value) = SparkConnectConfigHandler.toKeyValue(pair) + conf.set(key, value.orNull) + getWarning(key).foreach(builder.addWarnings) + } + builder + } + + private def handleGet( + operation: proto.ConfigRequest.Get, + conf: RuntimeConfig): 
proto.ConfigResponse.Builder = { + val builder = proto.ConfigResponse.newBuilder() + operation.getKeysList.asScala.iterator.foreach { key => + val value = conf.get(key) + builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value))) + getWarning(key).foreach(builder.addWarnings) + } + builder + } + + private def handleGetWithDefault( + operation: proto.ConfigRequest.GetWithDefault, + conf: RuntimeConfig): proto.ConfigResponse.Builder = { + val builder = proto.ConfigResponse.newBuilder() + operation.getPairsList.asScala.iterator.foreach { pair => + val (key, default) = SparkConnectConfigHandler.toKeyValue(pair) + val value = conf.get(key, default.orNull) + builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value))) + getWarning(key).foreach(builder.addWarnings) + } + builder + } + + private def handleGetOption( + operation: proto.ConfigRequest.GetOption, + conf: RuntimeConfig): proto.ConfigResponse.Builder = { + val builder = proto.ConfigResponse.newBuilder() + operation.getKeysList.asScala.iterator.foreach { key => + val value = conf.getOption(key) + builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, value)) + getWarning(key).foreach(builder.addWarnings) + } + builder + } + + private def handleGetAll( + operation: proto.ConfigRequest.GetAll, + conf: RuntimeConfig): proto.ConfigResponse.Builder = { + val builder = proto.ConfigResponse.newBuilder() + val results = if (operation.hasPrefix) { + val prefix = operation.getPrefix + conf.getAll.iterator + .filter { case (key, _) => key.startsWith(prefix) } + .map { case (key, value) => (key.substring(prefix.length), value) } + } else { + conf.getAll.iterator + } + results.foreach { case (key, value) => + builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value))) + getWarning(key).foreach(builder.addWarnings) + } + builder + } + + private def handleUnset( + operation: proto.ConfigRequest.Unset, + conf: RuntimeConfig): proto.ConfigResponse.Builder = { + val builder = proto.ConfigResponse.newBuilder() + operation.getKeysList.asScala.iterator.foreach { key => + conf.unset(key) + getWarning(key).foreach(builder.addWarnings) + } + builder + } + + private def handleIsModifiable( + operation: proto.ConfigRequest.IsModifiable, + conf: RuntimeConfig): proto.ConfigResponse.Builder = { + val builder = proto.ConfigResponse.newBuilder() + operation.getKeysList.asScala.iterator.foreach { key => + val value = conf.isModifiable(key) + builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value.toString))) + getWarning(key).foreach(builder.addWarnings) + } + builder + } + + private def getWarning(key: String): Option[String] = { + if (SparkConnectConfigHandler.unsupportedConfigurations.contains(key)) { + Some(s"The SQL config '$key' is NOT supported in Spark Connect") + } else { + SQLConf.deprecatedSQLConfigs.get(key).map(_.toDeprecationString) + } + } +} + +object SparkConnectConfigHandler { + + private[connect] val unsupportedConfigurations = Set("spark.sql.execution.arrow.enabled") + + def toKeyValue(pair: proto.KeyValue): (String, Option[String]) = { + val key = pair.getKey + val value = if (pair.hasValue) { + Some(pair.getValue) + } else { + None + } + (key, value) + } + + def toProtoKeyValue(key: String, value: Option[String]): proto.KeyValue = { + val builder = proto.KeyValue.newBuilder() + builder.setKey(key) + value.foreach(builder.setValue) + builder.build() + } +} diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 959aceaf46a38..227067e2fafb5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -212,6 +212,21 @@ class SparkConnectService(debug: Boolean)
     response.setIsStreaming(ds.isStreaming)
     response.addAllInputFiles(ds.inputFiles.toSeq.asJava)
   }
+
+  /**
+   * This is the main entry point for all Spark Connect requests to update or fetch the
+   * configuration.
+   *
+   * @param request the config request
+   * @param responseObserver the observer used to send back the [[ConfigResponse]]
+   */
+  override def config(
+      request: proto.ConfigRequest,
+      responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
+    try {
+      new SparkConnectConfigHandler(responseObserver).handle(request)
+    } catch handleError("config", observer = responseObserver)
+  }
 }
 
 /**
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 75a6b4401b86b..daecbc8485624 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -508,6 +508,7 @@ def __hash__(self):
     python_test_goals=[
         # doctests
         "pyspark.sql.connect.catalog",
+        "pyspark.sql.connect.conf",
         "pyspark.sql.connect.group",
         "pyspark.sql.connect.session",
         "pyspark.sql.connect.window",
@@ -523,6 +524,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.test_connect_column",
         "pyspark.sql.tests.connect.test_parity_datasources",
         "pyspark.sql.tests.connect.test_parity_catalog",
+        "pyspark.sql.tests.connect.test_parity_conf",
         "pyspark.sql.tests.connect.test_parity_serde",
         "pyspark.sql.tests.connect.test_parity_functions",
         "pyspark.sql.tests.connect.test_parity_group",
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index 9deb0147e6644..c48dc8449cd75 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -473,7 +473,7 @@ def default_session() -> SparkSession:
     # Turn ANSI off when testing the pandas API on Spark since
     # the behavior of pandas API on Spark follows pandas, not SQL.
     if is_testing():
-        spark.conf.set("spark.sql.ansi.enabled", False)  # type: ignore[arg-type]
+        spark.conf.set("spark.sql.ansi.enabled", False)
     if spark.conf.get("spark.sql.ansi.enabled") == "true":
         log_advice(
             "The config 'spark.sql.ansi.enabled' is set to True. "
diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 40a36a26701a6..e8b258c9bf826 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -28,6 +28,9 @@ class RuntimeConfig:
     """User-facing configuration API, accessible through `SparkSession.conf`.
 
     Options set here are automatically propagated to the Hadoop configuration during I/O.
+
+    .. versionchanged:: 3.4.0
+        Support Spark Connect.
     """
 
     def __init__(self, jconf: JavaObject) -> None:
@@ -35,14 +38,23 @@ def __init__(self, jconf: JavaObject) -> None:
         self._jconf = jconf
 
     @since(2.0)
-    def set(self, key: str, value: str) -> None:
-        """Sets the given Spark runtime configuration property."""
+    def set(self, key: str, value: Union[str, int, bool]) -> None:
+        """Sets the given Spark runtime configuration property.
+
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+ """ self._jconf.set(key, value) @since(2.0) - def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: + def get( + self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue + ) -> Optional[str]: """Returns the value of Spark runtime configuration property for the given key, assuming it is set. + + .. versionchanged:: 3.4.0 + Support Spark Connect. """ self._checkType(key, "key") if default is _NoValue: @@ -54,7 +66,11 @@ def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) @since(2.0) def unset(self, key: str) -> None: - """Resets the configuration property for the given key.""" + """Resets the configuration property for the given key. + + .. versionchanged:: 3.4.0 + Support Spark Connect. + """ self._jconf.unset(key) def _checkType(self, obj: Any, identifier: str) -> None: @@ -68,6 +84,9 @@ def _checkType(self, obj: Any, identifier: str) -> None: def isModifiable(self, key: str) -> bool: """Indicates whether the configuration property with the given key is modifiable in the current session. + + .. versionchanged:: 3.4.0 + Support Spark Connect. """ return self._jconf.isModifiable(key) diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 154dd161e9221..6ec10897fa40e 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -402,6 +402,19 @@ def fromProto(cls, pb: Any) -> "AnalyzeResult": ) +class ConfigResult: + def __init__(self, pairs: List[Tuple[str, Optional[str]]], warnings: List[str]): + self.pairs = pairs + self.warnings = warnings + + @classmethod + def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult": + return ConfigResult( + pairs=[(pair.key, pair.value if pair.HasField("value") else None) for pair in pb.pairs], + warnings=list(pb.warnings), + ) + + class SparkConnectClient(object): """ Conceptually the remote spark session that communicates with the server @@ -736,6 +749,45 @@ def _execute_and_fetch( metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else [] return table, metrics + def _config_request_with_metadata(self) -> pb2.ConfigRequest: + req = pb2.ConfigRequest() + req.client_id = self._session_id + req.client_type = self._builder.userAgent + if self._user_id: + req.user_context.user_id = self._user_id + return req + + def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult: + """ + Call the config RPC of Spark Connect. + + Parameters + ---------- + operation : str + Operation kind + + Returns + ------- + The result of the config call. + """ + req = self._config_request_with_metadata() + req.operation.CopyFrom(operation) + try: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + resp = self._stub.Config(req, metadata=self._builder.metadata()) + if resp.client_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request:" + f"{resp.client_id} != {self._session_id}" + ) + return ConfigResult.fromProto(resp) + raise SparkConnectException("Invalid state during retry exception handling.") + except grpc.RpcError as rpc_error: + self._handle_error(rpc_error) + def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn: """ Error handling helper for dealing with GRPC Errors. 
On the server side, certain diff --git a/python/pyspark/sql/connect/conf.py b/python/pyspark/sql/connect/conf.py new file mode 100644 index 0000000000000..d323de716c46a --- /dev/null +++ b/python/pyspark/sql/connect/conf.py @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Any, Optional, Union, cast +import warnings + +from pyspark import _NoValue +from pyspark._globals import _NoValueType +from pyspark.sql.conf import RuntimeConfig as PySparkRuntimeConfig +from pyspark.sql.connect import proto +from pyspark.sql.connect.client import SparkConnectClient + + +class RuntimeConf: + def __init__(self, client: SparkConnectClient) -> None: + """Create a new RuntimeConfig.""" + self._client = client + + __init__.__doc__ = PySparkRuntimeConfig.__init__.__doc__ + + def set(self, key: str, value: Union[str, int, bool]) -> None: + if isinstance(value, bool): + value = "true" if value else "false" + elif isinstance(value, int): + value = str(value) + op_set = proto.ConfigRequest.Set(pairs=[proto.KeyValue(key=key, value=value)]) + operation = proto.ConfigRequest.Operation(set=op_set) + result = self._client.config(operation) + for warn in result.warnings: + warnings.warn(warn) + + set.__doc__ = PySparkRuntimeConfig.set.__doc__ + + def get( + self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue + ) -> Optional[str]: + self._checkType(key, "key") + if default is _NoValue: + op_get = proto.ConfigRequest.Get(keys=[key]) + operation = proto.ConfigRequest.Operation(get=op_get) + else: + if default is not None: + self._checkType(default, "default") + op_get_with_default = proto.ConfigRequest.GetWithDefault( + pairs=[proto.KeyValue(key=key, value=cast(Optional[str], default))] + ) + operation = proto.ConfigRequest.Operation(get_with_default=op_get_with_default) + result = self._client.config(operation) + return result.pairs[0][1] + + get.__doc__ = PySparkRuntimeConfig.get.__doc__ + + def unset(self, key: str) -> None: + op_unset = proto.ConfigRequest.Unset(keys=[key]) + operation = proto.ConfigRequest.Operation(unset=op_unset) + result = self._client.config(operation) + for warn in result.warnings: + warnings.warn(warn) + + unset.__doc__ = PySparkRuntimeConfig.unset.__doc__ + + def isModifiable(self, key: str) -> bool: + op_is_modifiable = proto.ConfigRequest.IsModifiable(keys=[key]) + operation = proto.ConfigRequest.Operation(is_modifiable=op_is_modifiable) + result = self._client.config(operation) + if result.pairs[0][1] == "true": + return True + elif result.pairs[0][1] == "false": + return False + else: + raise ValueError(f"Unknown boolean value: {result.pairs[0][1]}") + + isModifiable.__doc__ = PySparkRuntimeConfig.isModifiable.__doc__ + + def _checkType(self, obj: Any, identifier: str) -> None: + """Assert that an 
object is of type str.""" + if not isinstance(obj, str): + raise TypeError( + "expected %s '%s' to be a string (was '%s')" % (identifier, obj, type(obj).__name__) + ) + + +RuntimeConf.__doc__ = PySparkRuntimeConfig.__doc__ + + +def _test() -> None: + import sys + import doctest + from pyspark.sql import SparkSession as PySparkSession + import pyspark.sql.connect.conf + + globs = pyspark.sql.connect.conf.__dict__.copy() + globs["spark"] = ( + PySparkSession.builder.appName("sql.connect.conf tests").remote("local[4]").getOrCreate() + ) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.connect.conf, + globs=globs, + optionflags=doctest.ELLIPSIS + | doctest.NORMALIZE_WHITESPACE + | doctest.IGNORE_EXCEPTION_DETAIL, + ) + + globs["spark"].stop() + + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 7d61a86c8b5e7..87dfe90107d0e 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -2477,11 +2477,6 @@ def _test() -> None: # Spark Connect does not support Spark Context but the test depends on that. del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__ - # TODO(SPARK-41834): implement Dataframe.conf - del pyspark.sql.connect.functions.from_unixtime.__doc__ - del pyspark.sql.connect.functions.timestamp_seconds.__doc__ - del pyspark.sql.connect.functions.unix_timestamp.__doc__ - # TODO(SPARK-41843): Implement SparkSession.udf del pyspark.sql.connect.functions.call_udf.__doc__ diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 0d86ce8cd687b..95951d8f8e3b6 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x02\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString\x12\x1f\n\x0btree_string\x18\x04 
\x01(\tR\ntreeString\x12\x19\n\x08is_local\x18\x05 \x01(\x08R\x07isLocal\x12!\n\x0cis_streaming\x18\x06 \x01(\x08R\x0bisStreaming\x12\x1f\n\x0binput_files\x18\x07 \x03(\tR\ninputFiles"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\x8f\x06\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12N\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchR\narrowBatch\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType2\xc7\x01\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain.ExplainModeR\x0b\x65xplainMode"c\n\x0b\x45xplainMode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\n\n\x06SIMPLE\x10\x01\x12\x0c\n\x08\x45XTENDED\x10\x02\x12\x0b\n\x07\x43ODEGEN\x10\x03\x12\x08\n\x04\x43OST\x10\x04\x12\r\n\tFORMATTED\x10\x05"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\x81\x02\n\x12\x41nalyzePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x30\n\x07\x65xplain\x18\x05 \x01(\x0b\x32\x16.spark.connect.ExplainR\x07\x65xplainB\x0e\n\x0c_client_type"\x8a\x02\n\x13\x41nalyzePlanResponse\x12\x1b\n\tclient_id\x18\x01 
\x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString\x12\x1f\n\x0btree_string\x18\x04 \x01(\tR\ntreeString\x12\x19\n\x08is_local\x18\x05 \x01(\x08R\x07isLocal\x12!\n\x0cis_streaming\x18\x06 \x01(\x08R\x0bisStreaming\x12\x1f\n\x0binput_files\x18\x07 \x03(\tR\ninputFiles"\xcf\x01\n\x12\x45xecutePlanRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\x8f\x06\n\x13\x45xecutePlanResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12N\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchR\narrowBatch\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x82\x08\n\rConfigRequest\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 
\x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"x\n\x0e\x43onfigResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings2\x90\x02\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -58,6 +58,17 @@ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[ "MetricValue" ] +_KEYVALUE = DESCRIPTOR.message_types_by_name["KeyValue"] +_CONFIGREQUEST = DESCRIPTOR.message_types_by_name["ConfigRequest"] +_CONFIGREQUEST_OPERATION = _CONFIGREQUEST.nested_types_by_name["Operation"] +_CONFIGREQUEST_SET = _CONFIGREQUEST.nested_types_by_name["Set"] +_CONFIGREQUEST_GET = _CONFIGREQUEST.nested_types_by_name["Get"] +_CONFIGREQUEST_GETWITHDEFAULT = _CONFIGREQUEST.nested_types_by_name["GetWithDefault"] +_CONFIGREQUEST_GETOPTION = _CONFIGREQUEST.nested_types_by_name["GetOption"] +_CONFIGREQUEST_GETALL = _CONFIGREQUEST.nested_types_by_name["GetAll"] +_CONFIGREQUEST_UNSET = _CONFIGREQUEST.nested_types_by_name["Unset"] +_CONFIGREQUEST_ISMODIFIABLE = _CONFIGREQUEST.nested_types_by_name["IsModifiable"] +_CONFIGRESPONSE = DESCRIPTOR.message_types_by_name["ConfigResponse"] _EXPLAIN_EXPLAINMODE = _EXPLAIN.enum_types_by_name["ExplainMode"] Plan = _reflection.GeneratedProtocolMessageType( "Plan", @@ -186,6 +197,119 @@ _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry) _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricValue) +KeyValue = _reflection.GeneratedProtocolMessageType( + "KeyValue", + (_message.Message,), + { + "DESCRIPTOR": _KEYVALUE, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.KeyValue) + }, +) +_sym_db.RegisterMessage(KeyValue) + +ConfigRequest = _reflection.GeneratedProtocolMessageType( + "ConfigRequest", + (_message.Message,), + { + "Operation": _reflection.GeneratedProtocolMessageType( + "Operation", + (_message.Message,), + { + "DESCRIPTOR": _CONFIGREQUEST_OPERATION, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.Operation) + }, + ), + "Set": _reflection.GeneratedProtocolMessageType( + "Set", + (_message.Message,), + { + "DESCRIPTOR": _CONFIGREQUEST_SET, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.Set) + }, + ), + "Get": _reflection.GeneratedProtocolMessageType( + "Get", + (_message.Message,), + { + "DESCRIPTOR": _CONFIGREQUEST_GET, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.Get) + }, + ), + "GetWithDefault": _reflection.GeneratedProtocolMessageType( + "GetWithDefault", + (_message.Message,), + { + "DESCRIPTOR": _CONFIGREQUEST_GETWITHDEFAULT, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.GetWithDefault) + }, + ), + "GetOption": _reflection.GeneratedProtocolMessageType( + "GetOption", 
+ (_message.Message,), + { + "DESCRIPTOR": _CONFIGREQUEST_GETOPTION, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.GetOption) + }, + ), + "GetAll": _reflection.GeneratedProtocolMessageType( + "GetAll", + (_message.Message,), + { + "DESCRIPTOR": _CONFIGREQUEST_GETALL, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.GetAll) + }, + ), + "Unset": _reflection.GeneratedProtocolMessageType( + "Unset", + (_message.Message,), + { + "DESCRIPTOR": _CONFIGREQUEST_UNSET, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.Unset) + }, + ), + "IsModifiable": _reflection.GeneratedProtocolMessageType( + "IsModifiable", + (_message.Message,), + { + "DESCRIPTOR": _CONFIGREQUEST_ISMODIFIABLE, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest.IsModifiable) + }, + ), + "DESCRIPTOR": _CONFIGREQUEST, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigRequest) + }, +) +_sym_db.RegisterMessage(ConfigRequest) +_sym_db.RegisterMessage(ConfigRequest.Operation) +_sym_db.RegisterMessage(ConfigRequest.Set) +_sym_db.RegisterMessage(ConfigRequest.Get) +_sym_db.RegisterMessage(ConfigRequest.GetWithDefault) +_sym_db.RegisterMessage(ConfigRequest.GetOption) +_sym_db.RegisterMessage(ConfigRequest.GetAll) +_sym_db.RegisterMessage(ConfigRequest.Unset) +_sym_db.RegisterMessage(ConfigRequest.IsModifiable) + +ConfigResponse = _reflection.GeneratedProtocolMessageType( + "ConfigResponse", + (_message.Message,), + { + "DESCRIPTOR": _CONFIGRESPONSE, + "__module__": "spark.connect.base_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.ConfigResponse) + }, +) +_sym_db.RegisterMessage(ConfigResponse) + _SPARKCONNECTSERVICE = DESCRIPTOR.services_by_name["SparkConnectService"] if _descriptor._USE_C_DESCRIPTORS == False: @@ -219,6 +343,28 @@ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2017 _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2019 _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2107 - _SPARKCONNECTSERVICE._serialized_start = 2110 - _SPARKCONNECTSERVICE._serialized_end = 2309 + _KEYVALUE._serialized_start = 2109 + _KEYVALUE._serialized_end = 2174 + _CONFIGREQUEST._serialized_start = 2177 + _CONFIGREQUEST._serialized_end = 3203 + _CONFIGREQUEST_OPERATION._serialized_start = 2395 + _CONFIGREQUEST_OPERATION._serialized_end = 2893 + _CONFIGREQUEST_SET._serialized_start = 2895 + _CONFIGREQUEST_SET._serialized_end = 2947 + _CONFIGREQUEST_GET._serialized_start = 2949 + _CONFIGREQUEST_GET._serialized_end = 2974 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 2976 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 3039 + _CONFIGREQUEST_GETOPTION._serialized_start = 3041 + _CONFIGREQUEST_GETOPTION._serialized_end = 3072 + _CONFIGREQUEST_GETALL._serialized_start = 3074 + _CONFIGREQUEST_GETALL._serialized_end = 3122 + _CONFIGREQUEST_UNSET._serialized_start = 3124 + _CONFIGREQUEST_UNSET._serialized_end = 3151 + _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 3153 + _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 3187 + _CONFIGRESPONSE._serialized_start = 3205 + _CONFIGRESPONSE._serialized_end = 3325 + _SPARKCONNECTSERVICE._serialized_start = 3328 + _SPARKCONNECTSERVICE._serialized_end = 3600 # @@protoc_insertion_point(module_scope) diff --git 
a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index ea82aaf21e252..f6c402b229f9e 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -570,3 +570,345 @@ class ExecutePlanResponse(google.protobuf.message.Message): ) -> None: ... global___ExecutePlanResponse = ExecutePlanResponse + +class KeyValue(google.protobuf.message.Message): + """The key-value pair for the config request and response.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + """(Required) The key.""" + value: builtins.str + """(Optional) The value.""" + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["_value", b"_value", "value", b"value"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_value", b"_value", "key", b"key", "value", b"value" + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_value", b"_value"] + ) -> typing_extensions.Literal["value"] | None: ... + +global___KeyValue = KeyValue + +class ConfigRequest(google.protobuf.message.Message): + """Request to update or fetch the configurations.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class Operation(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SET_FIELD_NUMBER: builtins.int + GET_FIELD_NUMBER: builtins.int + GET_WITH_DEFAULT_FIELD_NUMBER: builtins.int + GET_OPTION_FIELD_NUMBER: builtins.int + GET_ALL_FIELD_NUMBER: builtins.int + UNSET_FIELD_NUMBER: builtins.int + IS_MODIFIABLE_FIELD_NUMBER: builtins.int + @property + def set(self) -> global___ConfigRequest.Set: ... + @property + def get(self) -> global___ConfigRequest.Get: ... + @property + def get_with_default(self) -> global___ConfigRequest.GetWithDefault: ... + @property + def get_option(self) -> global___ConfigRequest.GetOption: ... + @property + def get_all(self) -> global___ConfigRequest.GetAll: ... + @property + def unset(self) -> global___ConfigRequest.Unset: ... + @property + def is_modifiable(self) -> global___ConfigRequest.IsModifiable: ... + def __init__( + self, + *, + set: global___ConfigRequest.Set | None = ..., + get: global___ConfigRequest.Get | None = ..., + get_with_default: global___ConfigRequest.GetWithDefault | None = ..., + get_option: global___ConfigRequest.GetOption | None = ..., + get_all: global___ConfigRequest.GetAll | None = ..., + unset: global___ConfigRequest.Unset | None = ..., + is_modifiable: global___ConfigRequest.IsModifiable | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "get", + b"get", + "get_all", + b"get_all", + "get_option", + b"get_option", + "get_with_default", + b"get_with_default", + "is_modifiable", + b"is_modifiable", + "op_type", + b"op_type", + "set", + b"set", + "unset", + b"unset", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "get", + b"get", + "get_all", + b"get_all", + "get_option", + b"get_option", + "get_with_default", + b"get_with_default", + "is_modifiable", + b"is_modifiable", + "op_type", + b"op_type", + "set", + b"set", + "unset", + b"unset", + ], + ) -> None: ... 
+ def WhichOneof( + self, oneof_group: typing_extensions.Literal["op_type", b"op_type"] + ) -> typing_extensions.Literal[ + "set", "get", "get_with_default", "get_option", "get_all", "unset", "is_modifiable" + ] | None: ... + + class Set(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + PAIRS_FIELD_NUMBER: builtins.int + @property + def pairs( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___KeyValue]: + """(Required) The config key-value pairs to set.""" + def __init__( + self, + *, + pairs: collections.abc.Iterable[global___KeyValue] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["pairs", b"pairs"]) -> None: ... + + class Get(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + @property + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """(Required) The config keys to get.""" + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ... + + class GetWithDefault(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + PAIRS_FIELD_NUMBER: builtins.int + @property + def pairs( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___KeyValue]: + """(Required) The config key-value paris to get. The value will be used as the default value.""" + def __init__( + self, + *, + pairs: collections.abc.Iterable[global___KeyValue] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["pairs", b"pairs"]) -> None: ... + + class GetOption(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + @property + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """(Required) The config keys to get optionally.""" + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ... + + class GetAll(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + PREFIX_FIELD_NUMBER: builtins.int + prefix: builtins.str + """(Optional) The prefix of the config key to get.""" + def __init__( + self, + *, + prefix: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["_prefix", b"_prefix", "prefix", b"prefix"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["_prefix", b"_prefix", "prefix", b"prefix"] + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_prefix", b"_prefix"] + ) -> typing_extensions.Literal["prefix"] | None: ... + + class Unset(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + @property + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """(Required) The config keys to unset.""" + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ... 
+ + class IsModifiable(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEYS_FIELD_NUMBER: builtins.int + @property + def keys( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """(Required) The config keys to check the config is modifiable.""" + def __init__( + self, + *, + keys: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["keys", b"keys"]) -> None: ... + + CLIENT_ID_FIELD_NUMBER: builtins.int + USER_CONTEXT_FIELD_NUMBER: builtins.int + OPERATION_FIELD_NUMBER: builtins.int + CLIENT_TYPE_FIELD_NUMBER: builtins.int + client_id: builtins.str + """(Required) + + The client_id is set by the client to be able to collate streaming responses from + different queries. + """ + @property + def user_context(self) -> global___UserContext: + """(Required) User context""" + @property + def operation(self) -> global___ConfigRequest.Operation: + """(Required) The operation for the config.""" + client_type: builtins.str + """Provides optional information about the client sending the request. This field + can be used for language or version specific information and is only intended for + logging purposes and will not be interpreted by the server. + """ + def __init__( + self, + *, + client_id: builtins.str = ..., + user_context: global___UserContext | None = ..., + operation: global___ConfigRequest.Operation | None = ..., + client_type: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_type", + b"client_type", + "operation", + b"operation", + "user_context", + b"user_context", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_id", + b"client_id", + "client_type", + b"client_type", + "operation", + b"operation", + "user_context", + b"user_context", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"] + ) -> typing_extensions.Literal["client_type"] | None: ... + +global___ConfigRequest = ConfigRequest + +class ConfigResponse(google.protobuf.message.Message): + """Response to the config request.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CLIENT_ID_FIELD_NUMBER: builtins.int + PAIRS_FIELD_NUMBER: builtins.int + WARNINGS_FIELD_NUMBER: builtins.int + client_id: builtins.str + @property + def pairs( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___KeyValue]: + """(Optional) The result key-value pairs. + + Available when the operation is 'Get', 'GetWithDefault', 'GetOption', 'GetAll'. + Also available for the operation 'IsModifiable' with boolean string "true" and "false". + """ + @property + def warnings( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """(Optional) + + Warning messages for deprecated or unsupported configurations. + """ + def __init__( + self, + *, + client_id: builtins.str = ..., + pairs: collections.abc.Iterable[global___KeyValue] | None = ..., + warnings: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "client_id", b"client_id", "pairs", b"pairs", "warnings", b"warnings" + ], + ) -> None: ... 
+ +global___ConfigResponse = ConfigResponse diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py index aff5897f520f8..007e31fd0ea3b 100644 --- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py +++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py @@ -40,6 +40,11 @@ def __init__(self, channel): request_serializer=spark_dot_connect_dot_base__pb2.AnalyzePlanRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.AnalyzePlanResponse.FromString, ) + self.Config = channel.unary_unary( + "/spark.connect.SparkConnectService/Config", + request_serializer=spark_dot_connect_dot_base__pb2.ConfigRequest.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.ConfigResponse.FromString, + ) class SparkConnectServiceServicer(object): @@ -60,6 +65,12 @@ def AnalyzePlan(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def Config(self, request, context): + """Update or fetch the configurations and returns a [[ConfigResponse]] containing the result.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def add_SparkConnectServiceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -73,6 +84,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server): request_deserializer=spark_dot_connect_dot_base__pb2.AnalyzePlanRequest.FromString, response_serializer=spark_dot_connect_dot_base__pb2.AnalyzePlanResponse.SerializeToString, ), + "Config": grpc.unary_unary_rpc_method_handler( + servicer.Config, + request_deserializer=spark_dot_connect_dot_base__pb2.ConfigRequest.FromString, + response_serializer=spark_dot_connect_dot_base__pb2.ConfigResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( "spark.connect.SparkConnectService", rpc_method_handlers @@ -141,3 +157,32 @@ def AnalyzePlan( timeout, metadata, ) + + @staticmethod + def Config( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/spark.connect.SparkConnectService/Config", + spark_dot_connect_dot_base__pb2.ConfigRequest.SerializeToString, + spark_dot_connect_dot_base__pb2.ConfigResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 08e63f544e23d..c95279a8c8e4c 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -47,6 +47,7 @@ from pyspark import SparkContext, SparkConf, __version__ from pyspark.sql.connect.client import SparkConnectClient +from pyspark.sql.connect.conf import RuntimeConf from pyspark.sql.connect.dataframe import DataFrame from pyspark.sql.connect.plan import SQL, Range, LocalRelation from pyspark.sql.connect.readwriter import DataFrameReader @@ -421,8 +422,8 @@ def newSession(self) -> Any: raise NotImplementedError("newSession() is not implemented.") @property - def conf(self) -> Any: - raise NotImplementedError("conf() is not implemented.") + def conf(self) -> RuntimeConf: + return RuntimeConf(self.client) @property def sparkContext(self) -> Any: diff --git 
a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 3c47ebfb97365..99f97977ccc7c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -191,9 +191,11 @@ def setConf(self, key: str, value: Union[bool, int, str]) -> None: .. versionadded:: 1.3.0 """ - self.sparkSession.conf.set(key, value) # type: ignore[arg-type] + self.sparkSession.conf.set(key, value) - def getConf(self, key: str, defaultValue: Union[Optional[str], _NoValueType] = _NoValue) -> str: + def getConf( + self, key: str, defaultValue: Union[Optional[str], _NoValueType] = _NoValue + ) -> Optional[str]: """Returns the value of Spark SQL configuration property for the given key. If the key is not set and defaultValue is set, return diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index adcd457a10515..84c3e4f23a6ab 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -2796,7 +2796,6 @@ def test_unsupported_session_functions(self): for f in ( "newSession", - "conf", "sparkContext", "streams", "readStream", diff --git a/python/pyspark/sql/tests/connect/test_parity_conf.py b/python/pyspark/sql/tests/connect/test_parity_conf.py new file mode 100644 index 0000000000000..554f05f27ea77 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_conf.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pyspark.sql.tests.test_conf import ConfTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class ConfParityTests(ConfTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_conf import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py index 07cae0fb27d0d..60f6e78024604 100644 --- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py +++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py @@ -46,11 +46,6 @@ def test_help_command(self): def test_invalid_join_method(self): super().test_invalid_join_method() - # TODO(SPARK-41834): Implement SparkSession.conf - @unittest.skip("Fails in Spark Connect, should enable.") - def test_join_without_on(self): - super().test_join_without_on() - # TODO(SPARK-41527): Implement DataFrame.observe @unittest.skip("Fails in Spark Connect, should enable.") def test_observe(self): @@ -75,11 +70,6 @@ def test_repartitionByRange_dataframe(self): def test_repr_behaviors(self): super().test_repr_behaviors() - # TODO(SPARK-41834): Implement SparkSession.conf - @unittest.skip("Fails in Spark Connect, should enable.") - def test_require_cross(self): - super().test_require_cross() - # TODO(SPARK-41874): Implement DataFrame `sameSemantics` @unittest.skip("Fails in Spark Connect, should enable.") def test_same_semantics_error(self): @@ -117,16 +107,6 @@ def test_to_pandas_for_array_of_struct(self): # Spark Connect's implementation is based on Arrow. super().check_to_pandas_for_array_of_struct(True) - # TODO(SPARK-41834): Implement SparkSession.conf - @unittest.skip("Fails in Spark Connect, should enable.") - def test_to_pandas_from_empty_dataframe(self): - super().test_to_pandas_from_empty_dataframe() - - # TODO(SPARK-41834): Implement SparkSession.conf - @unittest.skip("Fails in Spark Connect, should enable.") - def test_to_pandas_from_mixed_dataframe(self): - super().test_to_pandas_from_mixed_dataframe() - # TODO(SPARK-41834): Implement SparkSession.conf @unittest.skip("Fails in Spark Connect, should enable.") def test_to_pandas_from_null_dataframe(self): diff --git a/python/pyspark/sql/tests/test_conf.py b/python/pyspark/sql/tests/test_conf.py index a8fa59c036408..15722c2c57a40 100644 --- a/python/pyspark/sql/tests/test_conf.py +++ b/python/pyspark/sql/tests/test_conf.py @@ -14,11 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from decimal import Decimal +from pyspark.errors import IllegalArgumentException from pyspark.testing.sqlutils import ReusedSQLTestCase -class ConfTests(ReusedSQLTestCase): +class ConfTestsMixin: def test_conf(self): spark = self.spark spark.conf.set("bogo", "sipeo") @@ -42,6 +44,31 @@ def test_conf(self): # `defaultValue` in `spark.conf.get` is set to None. 
self.assertEqual(spark.conf.get("spark.sql.sources.partitionOverwriteMode", None), None) + self.assertTrue(spark.conf.isModifiable("spark.sql.execution.arrow.maxRecordsPerBatch")) + self.assertFalse(spark.conf.isModifiable("spark.sql.warehouse.dir")) + + def test_conf_with_python_objects(self): + spark = self.spark + + for value, expected in [(True, "true"), (False, "false")]: + spark.conf.set("foo", value) + self.assertEqual(spark.conf.get("foo"), expected) + + spark.conf.set("foo", 1) + self.assertEqual(spark.conf.get("foo"), "1") + + with self.assertRaises(IllegalArgumentException): + spark.conf.set("foo", None) + + with self.assertRaises(Exception): + spark.conf.set("foo", Decimal(1)) + + spark.conf.unset("foo") + + +class ConfTests(ConfTestsMixin, ReusedSQLTestCase): + pass + if __name__ == "__main__": import unittest diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d015e7df32b37..67a3f1b5fed07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4177,7 +4177,12 @@ object SQLConf { * @param comment Additional info regarding to the removed config. For example, * reasons of config deprecation, what users should use instead of it. */ - case class DeprecatedConfig(key: String, version: String, comment: String) + case class DeprecatedConfig(key: String, version: String, comment: String) { + def toDeprecationString: String = { + s"The SQL config '$key' has been deprecated in Spark v$version " + + s"and may be removed in the future. $comment" + } + } /** * Maps deprecated SQL config keys to information about the deprecation. @@ -5148,11 +5153,8 @@ class SQLConf extends Serializable with Logging { * Logs a warning message if the given config key is deprecated. */ private def logDeprecationWarning(key: String): Unit = { - SQLConf.deprecatedSQLConfigs.get(key).foreach { - case DeprecatedConfig(configName, version, comment) => - logWarning( - s"The SQL config '$configName' has been deprecated in Spark v$version " + - s"and may be removed in the future. $comment") + SQLConf.deprecatedSQLConfigs.get(key).foreach { config => + logWarning(config.toDeprecationString) } }
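
End-to-end, each `RuntimeConf` call in the Python client is translated into one `ConfigRequest` operation (`Set`, `Get`, `GetWithDefault`, `GetOption`, `GetAll`, `Unset`, `IsModifiable`), sent over the new `Config` RPC, and applied to the session's `RuntimeConfig` by `SparkConnectConfigHandler` on the server. Below is a minimal usage sketch of the resulting API, assuming a Spark Connect session created the same way as in the `pyspark.sql.connect.conf` doctest setup; the `"local[4]"` remote address and the config keys used here are illustrative only, not part of the patch.

```python
from pyspark.sql import SparkSession

# Create a Spark Connect session; "local[4]" mirrors the doctest setup in this
# patch and is only an illustrative remote address.
spark = SparkSession.builder.remote("local[4]").getOrCreate()

# set() stringifies ints/bools on the client and sends ConfigRequest.Set.
spark.conf.set("spark.sql.shuffle.partitions", 10)

# get() without a default sends ConfigRequest.Get; with a default it sends
# ConfigRequest.GetWithDefault, so a missing key falls back to the default.
assert spark.conf.get("spark.sql.shuffle.partitions") == "10"
assert spark.conf.get("some.nonexistent.key", None) is None

# isModifiable() sends ConfigRequest.IsModifiable and parses the "true"/"false" reply.
assert spark.conf.isModifiable("spark.sql.shuffle.partitions") is True

# unset() sends ConfigRequest.Unset; deprecation warnings returned by the server
# in ConfigResponse.warnings are surfaced as Python warnings on the client.
spark.conf.unset("spark.sql.shuffle.partitions")

spark.stop()
```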