@@ -183,6 +183,97 @@ message ExecutePlanResponse {
}
}

// The key-value pair for the config request and response.
message KeyValue {
// (Required) The key.
string key = 1;
// (Optional) The value.
optional string value = 2;
}

// Request to update or fetch the configurations.
message ConfigRequest {
// (Required)
//
// The client_id is set by the client so that it can collate streaming responses from
// different queries.
string client_id = 1;

// (Required) User context
UserContext user_context = 2;

// (Required) The operation for the config.
Operation operation = 3;

// Provides optional information about the client sending the request. This field
// can be used for language- or version-specific information. It is intended only for
// logging purposes and will not be interpreted by the server.
optional string client_type = 4;

message Operation {
oneof op_type {
Set set = 1;
Get get = 2;
GetWithDefault get_with_default = 3;
GetOption get_option = 4;
GetAll get_all = 5;
Unset unset = 6;
IsModifiable is_modifiable = 7;
}
}

message Set {
// (Required) The config key-value pairs to set.
repeated KeyValue pairs = 1;
}

message Get {
// (Required) The config keys to get.
repeated string keys = 1;
}

message GetWithDefault {
// (Required) The config key-value pairs to get. The value of each pair is used as the
// default if the key is not set.
repeated KeyValue pairs = 1;
}

message GetOption {
// (Required) The config keys to get optionally. A key that is not set yields a pair
// without a value.
repeated string keys = 1;
}

message GetAll {
// (Optional) The prefix of the config key to get.
optional string prefix = 1;
}

message Unset {
// (Required) The config keys to unset.
repeated string keys = 1;
}

message IsModifiable {
// (Required) The config keys to check for modifiability.
repeated string keys = 1;
}
}

// Response to the config request.
message ConfigResponse {
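// (Required) The client_id of the request, echoed back to the client.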
string client_id = 1;

// (Optional) The result key-value pairs.
//
// Available when the operation is 'Get', 'GetWithDefault', 'GetOption' or 'GetAll'.
// Also available for the 'IsModifiable' operation, where the values are the boolean
// strings "true" and "false".
repeated KeyValue pairs = 2;

// (Optional)
//
// Warning messages for deprecated or unsupported configurations.
repeated string warnings = 3;
}
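
For illustration, a minimal client-side sketch of how these messages compose, assuming the generated Python stubs ship as pyspark.sql.connect.proto (the import path and names below are assumptions, not part of this diff):

from pyspark.sql.connect import proto  # assumed location of generated stubs

request = proto.ConfigRequest()
request.client_id = "example-client"          # collates responses per client
request.user_context.user_id = "example-user"

# Exactly one field of the op_type oneof may be set per request.
request.operation.set.pairs.append(
    proto.KeyValue(key="spark.sql.shuffle.partitions", value="10")
)

# GetWithDefault reuses KeyValue: the value field carries the default.
fetch = proto.ConfigRequest(client_id="example-client")
fetch.operation.get_with_default.pairs.append(
    proto.KeyValue(key="spark.sql.ansi.enabled", value="false")
)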

// Main interface for the SparkConnect service.
service SparkConnectService {

@@ -193,5 +284,8 @@ service SparkConnectService {

// Analyzes a query and returns an [[AnalyzeResponse]] containing metadata about the query.
rpc AnalyzePlan(AnalyzePlanRequest) returns (AnalyzePlanResponse) {}

// Updates or fetches the configurations and returns a [[ConfigResponse]] containing the result.
rpc Config(ConfigRequest) returns (ConfigResponse) {}
}
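
A hedged sketch of calling the new RPC directly over gRPC; it assumes the generated stubs are importable as shown below and that a Spark Connect server listens on the default port 15002:

import grpc

from pyspark.sql.connect.proto import base_pb2, base_pb2_grpc  # assumed stub modules

channel = grpc.insecure_channel("localhost:15002")
stub = base_pb2_grpc.SparkConnectServiceStub(channel)

request = base_pb2.ConfigRequest(client_id="example-client")
request.user_context.user_id = "example-user"
request.operation.get.keys.append("spark.sql.shuffle.partitions")

# Config is a unary RPC: a single ConfigResponse comes back.
response = stub.Config(request)
for pair in response.pairs:
    print(pair.key, pair.value if pair.HasField("value") else None)
for warning in response.warnings:
    print("warning:", warning)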

@@ -0,0 +1,181 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import scala.collection.JavaConverters._

import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.internal.SQLConf

class SparkConnectConfigHandler(responseObserver: StreamObserver[proto.ConfigResponse])
extends Logging {

def handle(request: proto.ConfigRequest): Unit = {
val session =
SparkConnectService
.getOrCreateIsolatedSession(request.getUserContext.getUserId, request.getClientId)
.session

val builder = request.getOperation.getOpTypeCase match {
case proto.ConfigRequest.Operation.OpTypeCase.SET =>
handleSet(request.getOperation.getSet, session.conf)
case proto.ConfigRequest.Operation.OpTypeCase.GET =>
handleGet(request.getOperation.getGet, session.conf)
case proto.ConfigRequest.Operation.OpTypeCase.GET_WITH_DEFAULT =>
handleGetWithDefault(request.getOperation.getGetWithDefault, session.conf)
case proto.ConfigRequest.Operation.OpTypeCase.GET_OPTION =>
handleGetOption(request.getOperation.getGetOption, session.conf)
case proto.ConfigRequest.Operation.OpTypeCase.GET_ALL =>
handleGetAll(request.getOperation.getGetAll, session.conf)
case proto.ConfigRequest.Operation.OpTypeCase.UNSET =>
handleUnset(request.getOperation.getUnset, session.conf)
case proto.ConfigRequest.Operation.OpTypeCase.IS_MODIFIABLE =>
handleIsModifiable(request.getOperation.getIsModifiable, session.conf)
case _ => throw new UnsupportedOperationException(s"${request.getOperation} not supported.")
}

builder.setClientId(request.getClientId)
responseObserver.onNext(builder.build())
responseObserver.onCompleted()
}

private def handleSet(
operation: proto.ConfigRequest.Set,
conf: RuntimeConfig): proto.ConfigResponse.Builder = {
val builder = proto.ConfigResponse.newBuilder()
operation.getPairsList.asScala.iterator.foreach { pair =>
val (key, value) = SparkConnectConfigHandler.toKeyValue(pair)
conf.set(key, value.orNull)
getWarning(key).foreach(builder.addWarnings)
}
builder
}

private def handleGet(
operation: proto.ConfigRequest.Get,
conf: RuntimeConfig): proto.ConfigResponse.Builder = {
val builder = proto.ConfigResponse.newBuilder()
operation.getKeysList.asScala.iterator.foreach { key =>
val value = conf.get(key)
builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value)))
getWarning(key).foreach(builder.addWarnings)
}
builder
}

private def handleGetWithDefault(
operation: proto.ConfigRequest.GetWithDefault,
conf: RuntimeConfig): proto.ConfigResponse.Builder = {
val builder = proto.ConfigResponse.newBuilder()
operation.getPairsList.asScala.iterator.foreach { pair =>
val (key, default) = SparkConnectConfigHandler.toKeyValue(pair)
val value = conf.get(key, default.orNull)
builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value)))
getWarning(key).foreach(builder.addWarnings)
}
builder
}

private def handleGetOption(
operation: proto.ConfigRequest.GetOption,
conf: RuntimeConfig): proto.ConfigResponse.Builder = {
val builder = proto.ConfigResponse.newBuilder()
operation.getKeysList.asScala.iterator.foreach { key =>
val value = conf.getOption(key)
builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, value))
getWarning(key).foreach(builder.addWarnings)
}
builder
}

private def handleGetAll(
operation: proto.ConfigRequest.GetAll,
conf: RuntimeConfig): proto.ConfigResponse.Builder = {
val builder = proto.ConfigResponse.newBuilder()
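// When a prefix is given, only keys matching it are returned, and the prefix is
// stripped from each returned key.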
val results = if (operation.hasPrefix) {
val prefix = operation.getPrefix
conf.getAll.iterator
.filter { case (key, _) => key.startsWith(prefix) }
.map { case (key, value) => (key.substring(prefix.length), value) }
} else {
conf.getAll.iterator
}
results.foreach { case (key, value) =>
builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value)))
getWarning(key).foreach(builder.addWarnings)
}
builder
}

private def handleUnset(
operation: proto.ConfigRequest.Unset,
conf: RuntimeConfig): proto.ConfigResponse.Builder = {
val builder = proto.ConfigResponse.newBuilder()
operation.getKeysList.asScala.iterator.foreach { key =>
conf.unset(key)
getWarning(key).foreach(builder.addWarnings)
}
builder
}

private def handleIsModifiable(
operation: proto.ConfigRequest.IsModifiable,
conf: RuntimeConfig): proto.ConfigResponse.Builder = {
val builder = proto.ConfigResponse.newBuilder()
operation.getKeysList.asScala.iterator.foreach { key =>
val value = conf.isModifiable(key)
builder.addPairs(SparkConnectConfigHandler.toProtoKeyValue(key, Option(value.toString)))
getWarning(key).foreach(builder.addWarnings)
}
builder
}

private def getWarning(key: String): Option[String] = {
if (SparkConnectConfigHandler.unsupportedConfigurations.contains(key)) {
Some(s"The SQL config '$key' is NOT supported in Spark Connect")
} else {
SQLConf.deprecatedSQLConfigs.get(key).map(_.toDeprecationString)
}
}
}

object SparkConnectConfigHandler {

private[connect] val unsupportedConfigurations = Set("spark.sql.execution.arrow.enabled")

def toKeyValue(pair: proto.KeyValue): (String, Option[String]) = {
val key = pair.getKey
val value = if (pair.hasValue) {
Some(pair.getValue)
} else {
None
}
(key, value)
}

def toProtoKeyValue(key: String, value: Option[String]): proto.KeyValue = {
val builder = proto.KeyValue.newBuilder()
builder.setKey(key)
value.foreach(builder.setValue)
builder.build()
}
}
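
On the client side, the same optional-value convention can be mirrored. A sketch in Python (the helper names are illustrative; only proto.KeyValue comes from the generated stubs, whose location is assumed):

from typing import Optional, Tuple

from pyspark.sql.connect import proto  # assumed stub location


def to_key_value(pair: proto.KeyValue) -> Tuple[str, Optional[str]]:
    # proto3 `optional` tracks field presence, so HasField distinguishes
    # an unset value from an empty string.
    return pair.key, pair.value if pair.HasField("value") else None


def to_proto_key_value(key: str, value: Optional[str]) -> proto.KeyValue:
    pair = proto.KeyValue(key=key)
    if value is not None:
        pair.value = value
    return pair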
@@ -212,6 +212,21 @@ class SparkConnectService(debug: Boolean)
response.setIsStreaming(ds.isStreaming)
response.addAllInputFiles(ds.inputFiles.toSeq.asJava)
}

/**
* This is the main entry method for all Spark Connect calls to update or fetch
* configurations.
*
* @param request the config request
* @param responseObserver the observer used to send back the [[proto.ConfigResponse]]
*/
override def config(
request: proto.ConfigRequest,
responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
try {
new SparkConnectConfigHandler(responseObserver).handle(request)
} catch handleError("config", observer = responseObserver)
}
}

/**
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -508,6 +508,7 @@ def __hash__(self):
python_test_goals=[
# doctests
"pyspark.sql.connect.catalog",
"pyspark.sql.connect.conf",
"pyspark.sql.connect.group",
"pyspark.sql.connect.session",
"pyspark.sql.connect.window",
@@ -523,6 +524,7 @@
"pyspark.sql.tests.connect.test_connect_column",
"pyspark.sql.tests.connect.test_parity_datasources",
"pyspark.sql.tests.connect.test_parity_catalog",
"pyspark.sql.tests.connect.test_parity_conf",
"pyspark.sql.tests.connect.test_parity_serde",
"pyspark.sql.tests.connect.test_parity_functions",
"pyspark.sql.tests.connect.test_parity_group",
2 changes: 1 addition & 1 deletion python/pyspark/pandas/utils.py
@@ -473,7 +473,7 @@ def default_session() -> SparkSession:
# Turn ANSI off when testing the pandas API on Spark since
# the behavior of pandas API on Spark follows pandas, not SQL.
if is_testing():
spark.conf.set("spark.sql.ansi.enabled", False) # type: ignore[arg-type]
spark.conf.set("spark.sql.ansi.enabled", False)
if spark.conf.get("spark.sql.ansi.enabled") == "true":
log_advice(
"The config 'spark.sql.ansi.enabled' is set to True. "
27 changes: 23 additions & 4 deletions python/pyspark/sql/conf.py
@@ -28,21 +28,33 @@ class RuntimeConfig:
"""User-facing configuration API, accessible through `SparkSession.conf`.

Options set here are automatically propagated to the Hadoop configuration during I/O.

.. versionchanged:: 3.4.0
Support Spark Connect.
"""

def __init__(self, jconf: JavaObject) -> None:
"""Create a new RuntimeConfig that wraps the underlying JVM object."""
self._jconf = jconf

@since(2.0)
def set(self, key: str, value: str) -> None:
"""Sets the given Spark runtime configuration property."""
def set(self, key: str, value: Union[str, int, bool]) -> None:
"""Sets the given Spark runtime configuration property.

.. versionchanged:: 3.4.0
Support Spark Connect.
"""
self._jconf.set(key, value)

@since(2.0)
def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str:
def get(
self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue
) -> Optional[str]:
"""Returns the value of Spark runtime configuration property for the given key,
assuming it is set.

.. versionchanged:: 3.4.0
Support Spark Connect.
"""
self._checkType(key, "key")
if default is _NoValue:
@@ -54,7 +66,11 @@

@since(2.0)
def unset(self, key: str) -> None:
"""Resets the configuration property for the given key."""
"""Resets the configuration property for the given key.

.. versionchanged:: 3.4.0
Support Spark Connect.
"""
self._jconf.unset(key)

def _checkType(self, obj: Any, identifier: str) -> None:
@@ -68,6 +84,9 @@ def _checkType(self, obj: Any, identifier: str) -> None:
def isModifiable(self, key: str) -> bool:
"""Indicates whether the configuration property with the given key
is modifiable in the current session.

.. versionchanged:: 3.4.0
Support Spark Connect.
"""
return self._jconf.isModifiable(key)
