@@ -31,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto";
// produce a relational result.
message Command {
oneof command_type {
CommonInlineUserDefinedFunction register_function = 1;
WriteOperation write_operation = 2;
CreateDataFrameViewCommand create_dataframe_view = 3;
WriteOperationV2 write_operation_v2 = 4;
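The new register_function branch of the command_type oneof carries a CommonInlineUserDefinedFunction (defined in expressions.proto). A minimal sketch of building such a command with the generated Python protobuf classes; the field values are illustrative only, not taken from this PR:

# Build a register_function command directly from the generated classes.
from pyspark.sql.connect.proto import commands_pb2

cmd = commands_pb2.Command()
udf = cmd.register_function                   # mutating this sub-message selects the oneof branch
udf.function_name = "my_udf"
udf.deterministic = True
udf.python_udf.output_type = '"integer"'      # JSON-encoded return type; DDL is also accepted server-side
udf.python_udf.eval_type = 100                # PythonEvalType.SQL_BATCHED_UDF
udf.python_udf.command = b"<cloudpickled (func, return_type)>"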
@@ -44,6 +44,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.internal.CatalogImpl
import org.apache.spark.sql.types._
@@ -1399,6 +1400,8 @@ class SparkConnectPlanner(val session: SparkSession) {

def process(command: proto.Command): Unit = {
command.getCommandTypeCase match {
case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
handleRegisterUserDefinedFunction(command.getRegisterFunction)
case proto.Command.CommandTypeCase.WRITE_OPERATION =>
handleWriteOperation(command.getWriteOperation)
case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
@@ -1411,6 +1414,36 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}

private def handleRegisterUserDefinedFunction(
fun: proto.CommonInlineUserDefinedFunction): Unit = {
fun.getFunctionCase match {
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
handleRegisterPythonUDF(fun)
case _ =>
throw InvalidPlanInput(
s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
}
}

private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
val udf = fun.getPythonUdf
val function = transformPythonFunction(udf)
val udpf = UserDefinedPythonFunction(
name = fun.getFunctionName,
func = function,
dataType = DataType.parseTypeWithFallback(
schema = udf.getOutputType,
parser = DataType.fromDDL,
fallbackParser = DataType.fromJson) match {
case s: DataType => s
case other => throw InvalidPlanInput(s"Invalid return type $other")
},
pythonEvalType = udf.getEvalType,
udfDeterministic = fun.getDeterministic)

session.udf.registerPython(fun.getFunctionName, udpf)
}

private def handleCommandPlugin(extension: ProtoAny): Unit = {
SparkConnectPluginRegistry.commandRegistry
// Lazily traverse the collection.
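On the server, the declared return type is parsed with DataType.parseTypeWithFallback: DDL first, then a JSON-serialized DataType as fallback, which matches what the Python client sends via return_type.json(). A small illustrative sketch of the two encodings the handler above accepts (the variable names are only for illustration):

# Two equivalent ways the client can encode an integer return type.
from pyspark.sql.types import IntegerType

ddl_form = "int"                   # parsed by DataType.fromDDL
json_form = IntegerType().json()   # '"integer"', parsed by the DataType.fromJson fallback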
59 changes: 59 additions & 0 deletions python/pyspark/sql/connect/client.py
@@ -30,6 +30,7 @@
import urllib.parse
import uuid
import json
import sys
from types import TracebackType
from typing import (
Iterable,
@@ -67,11 +68,18 @@
TempTableAlreadyExistsException,
IllegalArgumentException,
)
from pyspark.sql.connect.expressions import (
PythonUDF,
CommonInlineUserDefinedFunction,
)
from pyspark.sql.types import (
DataType,
StructType,
StructField,
)
from pyspark.sql.utils import is_remote
from pyspark.serializers import CloudPickleSerializer
from pyspark.rdd import PythonEvalType


def _configure_logging() -> logging.Logger:
@@ -428,6 +436,57 @@ def __init__(
self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
# Configure logging for the SparkConnect client.

def register_udf(
self,
function: Any,
return_type: Union[str, DataType],
name: Optional[str] = None,
eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
deterministic: bool = True,
) -> str:
"""Create a temporary UDF in the session catalog on the other side. We generate a
temporary name for it."""

from pyspark.sql import SparkSession as PySparkSession

if name is None:
name = f"fun_{uuid.uuid4().hex}"

# convert str return_type to DataType
if isinstance(return_type, str):

assert is_remote()
return_type_schema = ( # a workaround to parse the DataType from DDL strings
PySparkSession.builder.getOrCreate()
.createDataFrame(data=[], schema=return_type)
.schema
)
assert len(return_type_schema.fields) == 1, "returnType should be singular"
return_type = return_type_schema.fields[0].dataType

# construct a PythonUDF
py_udf = PythonUDF(
output_type=return_type.json(),
eval_type=eval_type,
command=CloudPickleSerializer().dumps((function, return_type)),
python_ver="%d.%d" % sys.version_info[:2],
)

# construct a CommonInlineUserDefinedFunction
fun = CommonInlineUserDefinedFunction(
function_name=name,
deterministic=deterministic,
arguments=[],
function=py_udf,
).to_command(self)

# construct the request
req = self._execute_plan_request_with_metadata()
req.plan.command.register_function.CopyFrom(fun)

self._execute(req)
return name

def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
return [
PlanMetrics(
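A minimal usage sketch of the new client method, assuming a reachable Spark Connect endpoint; the connection string and UDF body are illustrative:

# Registers a Python UDF under a generated temporary name and returns that name.
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.types import IntegerType

client = SparkConnectClient("sc://localhost:15002")
name = client.register_udf(lambda x: x + 1, IntegerType())
# `name` can now be used to reference the UDF in SQL executed through this session.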
7 changes: 7 additions & 0 deletions python/pyspark/sql/connect/expressions.py
@@ -542,6 +542,13 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
)
return expr

def to_command(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction":
expr = proto.CommonInlineUserDefinedFunction()
expr.function_name = self._function_name
expr.deterministic = self._deterministic
expr.python_udf.CopyFrom(self._function.to_plan(session))
return expr

def __repr__(self) -> str:
return (
f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])}), "
40 changes: 20 additions & 20 deletions python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcc\x02\n\x07\x43ommand\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xe6\x05\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12\x1f\n\ntable_name\x18\x04 \x01(\tH\x00R\ttableName\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_type"\x9b\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1a\n\x08provider\x18\x03 \x01(\tR\x08provider\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xe6\x05\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12\x1f\n\ntable_name\x18\x04 \x01(\tH\x00R\ttableName\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_type"\x9b\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1a\n\x08provider\x18\x03 \x01(\tR\x08provider\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
)


@@ -147,23 +147,23 @@
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001"
_COMMAND._serialized_start = 166
_COMMAND._serialized_end = 498
_CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 501
_CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 651
_WRITEOPERATION._serialized_start = 654
_WRITEOPERATION._serialized_end = 1396
_WRITEOPERATION_OPTIONSENTRY._serialized_start = 1092
_WRITEOPERATION_OPTIONSENTRY._serialized_end = 1150
_WRITEOPERATION_BUCKETBY._serialized_start = 1152
_WRITEOPERATION_BUCKETBY._serialized_end = 1243
_WRITEOPERATION_SAVEMODE._serialized_start = 1246
_WRITEOPERATION_SAVEMODE._serialized_end = 1383
_WRITEOPERATIONV2._serialized_start = 1399
_WRITEOPERATIONV2._serialized_end = 2194
_WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1092
_WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1150
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 1966
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2032
_WRITEOPERATIONV2_MODE._serialized_start = 2035
_WRITEOPERATIONV2_MODE._serialized_end = 2194
_COMMAND._serialized_end = 593
_CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
_CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
_WRITEOPERATION._serialized_start = 749
_WRITEOPERATION._serialized_end = 1491
_WRITEOPERATION_OPTIONSENTRY._serialized_start = 1187
_WRITEOPERATION_OPTIONSENTRY._serialized_end = 1245
_WRITEOPERATION_BUCKETBY._serialized_start = 1247
_WRITEOPERATION_BUCKETBY._serialized_end = 1338
_WRITEOPERATION_SAVEMODE._serialized_start = 1341
_WRITEOPERATION_SAVEMODE._serialized_end = 1478
_WRITEOPERATIONV2._serialized_start = 1494
_WRITEOPERATIONV2._serialized_end = 2289
_WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1187
_WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1245
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2061
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2127
_WRITEOPERATIONV2_MODE._serialized_start = 2130
_WRITEOPERATIONV2_MODE._serialized_end = 2289
# @@protoc_insertion_point(module_scope)
17 changes: 16 additions & 1 deletion python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -59,11 +59,16 @@ class Command(google.protobuf.message.Message):

DESCRIPTOR: google.protobuf.descriptor.Descriptor

REGISTER_FUNCTION_FIELD_NUMBER: builtins.int
WRITE_OPERATION_FIELD_NUMBER: builtins.int
CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int
WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def register_function(
self,
) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction: ...
@property
def write_operation(self) -> global___WriteOperation: ...
@property
def create_dataframe_view(self) -> global___CreateDataFrameViewCommand: ...
@@ -77,6 +82,8 @@ class Command(google.protobuf.message.Message):
def __init__(
self,
*,
register_function: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
| None = ...,
write_operation: global___WriteOperation | None = ...,
create_dataframe_view: global___CreateDataFrameViewCommand | None = ...,
write_operation_v2: global___WriteOperationV2 | None = ...,
@@ -91,6 +98,8 @@
b"create_dataframe_view",
"extension",
b"extension",
"register_function",
b"register_function",
"write_operation",
b"write_operation",
"write_operation_v2",
@@ -106,6 +115,8 @@
b"create_dataframe_view",
"extension",
b"extension",
"register_function",
b"register_function",
"write_operation",
b"write_operation",
"write_operation_v2",
@@ -115,7 +126,11 @@
def WhichOneof(
self, oneof_group: typing_extensions.Literal["command_type", b"command_type"]
) -> typing_extensions.Literal[
"write_operation", "create_dataframe_view", "write_operation_v2", "extension"
"register_function",
"write_operation",
"create_dataframe_view",
"write_operation_v2",
"extension",
] | None: ...

global___Command = Command
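With the stub update, WhichOneof("command_type") can now report "register_function". A short self-contained sketch of that dispatch point (illustrative values only):

# WhichOneof reflects the newly added oneof branch.
from pyspark.sql.connect.proto import commands_pb2

cmd = commands_pb2.Command()
cmd.register_function.function_name = "my_udf"
assert cmd.WhichOneof("command_type") == "register_function"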
9 changes: 7 additions & 2 deletions python/pyspark/sql/connect/session.py
@@ -68,6 +68,7 @@
if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType
from pyspark.sql.connect.catalog import Catalog
from pyspark.sql.connect.udf import UDFRegistration


class SparkSession:
@@ -436,8 +437,12 @@ def readStream(self) -> Any:
raise NotImplementedError("readStream() is not implemented.")

@property
def udf(self) -> Any:
raise NotImplementedError("udf() is not implemented.")
def udf(self) -> "UDFRegistration":
from pyspark.sql.connect.udf import UDFRegistration

return UDFRegistration(self)

udf.__doc__ = PySparkSession.udf.__doc__

@property
def version(self) -> str:
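With the udf property now returning a UDFRegistration, registration can go through the Connect session much like classic PySpark. A minimal end-to-end sketch, assuming a reachable Connect server at the given address and that UDFRegistration.register mirrors the non-Connect signature:

# Register a UDF via the Connect session and use it from SQL.
from pyspark.sql.connect.session import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
spark.udf.register("plus_one", lambda x: x + 1, IntegerType())
spark.sql("SELECT plus_one(id) FROM range(5)").show()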