diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 05c91d2c992dc..7321869757758 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -31,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto";
 // produce a relational result.
 message Command {
   oneof command_type {
+    CommonInlineUserDefinedFunction register_function = 1;
     WriteOperation write_operation = 2;
     CreateDataFrameViewCommand create_dataframe_view = 3;
     WriteOperationV2 write_operation_v2 = 4;
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c8a0860b871d9..3bf5d2b1d30bf 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.CreateViewCommand
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.functions.{col, expr}
 import org.apache.spark.sql.internal.CatalogImpl
 import org.apache.spark.sql.types._
@@ -1399,6 +1400,8 @@ class SparkConnectPlanner(val session: SparkSession) {
   def process(command: proto.Command): Unit = {
     command.getCommandTypeCase match {
+      case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
+        handleRegisterUserDefinedFunction(command.getRegisterFunction)
       case proto.Command.CommandTypeCase.WRITE_OPERATION =>
         handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
@@ -1411,6 +1414,36 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
+  private def handleRegisterUserDefinedFunction(
+      fun: proto.CommonInlineUserDefinedFunction): Unit = {
+    fun.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        handleRegisterPythonUDF(fun)
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
+    val udf = fun.getPythonUdf
+    val function = transformPythonFunction(udf)
+    val udpf = UserDefinedPythonFunction(
+      name = fun.getFunctionName,
+      func = function,
+      dataType = DataType.parseTypeWithFallback(
+        schema = udf.getOutputType,
+        parser = DataType.fromDDL,
+        fallbackParser = DataType.fromJson) match {
+        case s: DataType => s
+        case other => throw InvalidPlanInput(s"Invalid return type $other")
+      },
+      pythonEvalType = udf.getEvalType,
+      udfDeterministic = fun.getDeterministic)
+
+    session.udf.registerPython(fun.getFunctionName, udpf)
+  }
+
   private def handleCommandPlugin(extension: ProtoAny): Unit = {
     SparkConnectPluginRegistry.commandRegistry
       // Lazily traverse the collection.
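For reference, here is a minimal Python sketch (not part of the patch) of the message a client would send to exercise the new `register_function` command. It assumes the generated classes are re-exported via `pyspark.sql.connect.proto` and that the `PythonUDF` message exposes the same `output_type`/`eval_type`/`command`/`python_ver` fields that the client code further down in this diff populates:

```python
import sys

from pyspark.rdd import PythonEvalType
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import LongType
import pyspark.sql.connect.proto as pb2  # assumed re-export of the generated *_pb2 modules

command = pb2.Command()
udf = command.register_function  # selects the new `register_function` arm of the oneof
udf.function_name = "add_one"
udf.deterministic = True
udf.python_udf.eval_type = PythonEvalType.SQL_BATCHED_UDF
# The client sends the return type in JSON form; the server tries DDL first, then JSON.
udf.python_udf.output_type = LongType().json()
udf.python_udf.python_ver = "%d.%d" % sys.version_info[:2]
# The pickled payload is the (function, return type) pair, as in register_udf below.
udf.python_udf.command = CloudPickleSerializer().dumps((lambda x: x + 1, LongType()))
```

On the server side, `SparkConnectPlanner.process` dispatches `REGISTER_FUNCTION` to `handleRegisterPythonUDF`, which wraps the payload in a `UserDefinedPythonFunction` and registers it under `function_name` in the session's UDF registry.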
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 8cf5fa5069357..903981a015bed 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -30,6 +30,7 @@
 import urllib.parse
 import uuid
 import json
+import sys
 from types import TracebackType
 from typing import (
     Iterable,
@@ -67,11 +68,18 @@
     TempTableAlreadyExistsException,
     IllegalArgumentException,
 )
+from pyspark.sql.connect.expressions import (
+    PythonUDF,
+    CommonInlineUserDefinedFunction,
+)
 from pyspark.sql.types import (
     DataType,
     StructType,
     StructField,
 )
+from pyspark.sql.utils import is_remote
+from pyspark.serializers import CloudPickleSerializer
+from pyspark.rdd import PythonEvalType


 def _configure_logging() -> logging.Logger:
@@ -428,6 +436,57 @@ def __init__(
         self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
         # Configure logging for the SparkConnect client.

+    def register_udf(
+        self,
+        function: Any,
+        return_type: Union[str, DataType],
+        name: Optional[str] = None,
+        eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
+        deterministic: bool = True,
+    ) -> str:
+        """Create a temporary UDF in the session catalog on the other side. We generate a
+        temporary name for it."""
+
+        from pyspark.sql import SparkSession as PySparkSession
+
+        if name is None:
+            name = f"fun_{uuid.uuid4().hex}"
+
+        # convert str return_type to DataType
+        if isinstance(return_type, str):
+
+            assert is_remote()
+            return_type_schema = (  # a workaround to parse the DataType from DDL strings
+                PySparkSession.builder.getOrCreate()
+                .createDataFrame(data=[], schema=return_type)
+                .schema
+            )
+            assert len(return_type_schema.fields) == 1, "returnType should be singular"
+            return_type = return_type_schema.fields[0].dataType
+
+        # construct a PythonUDF
+        py_udf = PythonUDF(
+            output_type=return_type.json(),
+            eval_type=eval_type,
+            command=CloudPickleSerializer().dumps((function, return_type)),
+            python_ver="%d.%d" % sys.version_info[:2],
+        )
+
+        # construct a CommonInlineUserDefinedFunction
+        fun = CommonInlineUserDefinedFunction(
+            function_name=name,
+            deterministic=deterministic,
+            arguments=[],
+            function=py_udf,
+        ).to_command(self)
+
+        # construct the request
+        req = self._execute_plan_request_with_metadata()
+        req.plan.command.register_function.CopyFrom(fun)
+
+        self._execute(req)
+        return name
+
     def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
         return [
             PlanMetrics(
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index dcd7c5ebba6b2..28b796496eca8 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -542,6 +542,13 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         )
         return expr

+    def to_command(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction":
+        expr = proto.CommonInlineUserDefinedFunction()
+        expr.function_name = self._function_name
+        expr.deterministic = self._deterministic
+        expr.python_udf.CopyFrom(self._function.to_plan(session))
+        return expr
+
     def __repr__(self) -> str:
         return (
             f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])}), "
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index c9a51b04bb6d7..f7e9260212e68 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@
 DESCRIPTOR =
_descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcc\x02\n\x07\x43ommand\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xe6\x05\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12\x1f\n\ntable_name\x18\x04 \x01(\tH\x00R\ttableName\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_type"\x9b\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1a\n\x08provider\x18\x03 \x01(\tR\x08provider\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + 
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xe6\x05\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12\x1f\n\ntable_name\x18\x04 \x01(\tH\x00R\ttableName\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_type"\x9b\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1a\n\x08provider\x18\x03 \x01(\tR\x08provider\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 
\x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
 )
@@ -147,23 +147,23 @@
     _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None
     _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001"
     _COMMAND._serialized_start = 166
-    _COMMAND._serialized_end = 498
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 501
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 651
-    _WRITEOPERATION._serialized_start = 654
-    _WRITEOPERATION._serialized_end = 1396
-    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1092
-    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1150
-    _WRITEOPERATION_BUCKETBY._serialized_start = 1152
-    _WRITEOPERATION_BUCKETBY._serialized_end = 1243
-    _WRITEOPERATION_SAVEMODE._serialized_start = 1246
-    _WRITEOPERATION_SAVEMODE._serialized_end = 1383
-    _WRITEOPERATIONV2._serialized_start = 1399
-    _WRITEOPERATIONV2._serialized_end = 2194
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1092
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1150
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 1966
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2032
-    _WRITEOPERATIONV2_MODE._serialized_start = 2035
-    _WRITEOPERATIONV2_MODE._serialized_end = 2194
+    _COMMAND._serialized_end = 593
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
+    _WRITEOPERATION._serialized_start = 749
+    _WRITEOPERATION._serialized_end = 1491
+    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1187
+    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1245
+    _WRITEOPERATION_BUCKETBY._serialized_start = 1247
+    _WRITEOPERATION_BUCKETBY._serialized_end = 1338
+    _WRITEOPERATION_SAVEMODE._serialized_start = 1341
+    _WRITEOPERATION_SAVEMODE._serialized_end = 1478
+    _WRITEOPERATIONV2._serialized_start = 1494
+    _WRITEOPERATIONV2._serialized_end = 2289
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1187
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1245
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2061
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2127
+    _WRITEOPERATIONV2_MODE._serialized_start = 2130
+    _WRITEOPERATIONV2_MODE._serialized_end = 2289
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 8a1f2ffb12257..4bdf1f1ed4ed3 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -59,11 +59,16 @@ class Command(google.protobuf.message.Message):

     DESCRIPTOR: google.protobuf.descriptor.Descriptor

+    REGISTER_FUNCTION_FIELD_NUMBER: builtins.int
     WRITE_OPERATION_FIELD_NUMBER: builtins.int
     CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int
     WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
+    def register_function(
+        self,
+    ) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction: ...
+    @property
     def write_operation(self) -> global___WriteOperation: ...
     @property
     def create_dataframe_view(self) -> global___CreateDataFrameViewCommand: ...
@@ -77,6 +82,8 @@ class Command(google.protobuf.message.Message):
     def __init__(
         self,
         *,
+        register_function: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+        | None = ...,
         write_operation: global___WriteOperation | None = ...,
         create_dataframe_view: global___CreateDataFrameViewCommand | None = ...,
         write_operation_v2: global___WriteOperationV2 | None = ...,
@@ -91,6 +98,8 @@ class Command(google.protobuf.message.Message):
             b"create_dataframe_view",
             "extension",
             b"extension",
+            "register_function",
+            b"register_function",
             "write_operation",
             b"write_operation",
             "write_operation_v2",
@@ -106,6 +115,8 @@ class Command(google.protobuf.message.Message):
             b"create_dataframe_view",
             "extension",
             b"extension",
+            "register_function",
+            b"register_function",
             "write_operation",
             b"write_operation",
             "write_operation_v2",
@@ -115,7 +126,11 @@ class Command(google.protobuf.message.Message):
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["command_type", b"command_type"]
     ) -> typing_extensions.Literal[
-        "write_operation", "create_dataframe_view", "write_operation_v2", "extension"
+        "register_function",
+        "write_operation",
+        "create_dataframe_view",
+        "write_operation_v2",
+        "extension",
     ] | None: ...

 global___Command = Command
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 898baa45b03ce..3c44d06bb1c58 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -68,6 +68,7 @@
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
     from pyspark.sql.connect.catalog import Catalog
+    from pyspark.sql.connect.udf import UDFRegistration


 class SparkSession:
@@ -436,8 +437,12 @@ def readStream(self) -> Any:
         raise NotImplementedError("readStream() is not implemented.")

     @property
-    def udf(self) -> Any:
-        raise NotImplementedError("udf() is not implemented.")
+    def udf(self) -> "UDFRegistration":
+        from pyspark.sql.connect.udf import UDFRegistration
+
+        return UDFRegistration(self)
+
+    udf.__doc__ = PySparkSession.udf.__doc__

     @property
     def version(self) -> str:
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 6571cf7692983..573d8f582e287 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -23,8 +23,9 @@
 import sys
 import functools

-from typing import Callable, Any, TYPE_CHECKING, Optional
+from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union

+from pyspark.rdd import PythonEvalType
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.connect.expressions import (
     ColumnReference,
@@ -33,6 +34,7 @@
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.types import DataType, StringType
+from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
 from pyspark.sql.utils import is_remote


@@ -42,6 +44,7 @@
         DataTypeOrString,
         UserDefinedFunctionLike,
     )
+    from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.types import StringType


@@ -75,7 +78,7 @@ def __init__(
         func: Callable[..., Any],
         returnType: "DataTypeOrString" = StringType(),
         name: Optional[str] = None,
-        evalType: int = 100,
+        evalType: int = PythonEvalType.SQL_BATCHED_UDF,
         deterministic: bool = True,
     ):
         if not callable(func):
@@ -187,3 +190,54 @@ def asNondeterministic(self) -> "UserDefinedFunction":
         """
         self.deterministic = False
         return self
+
+
+class UDFRegistration:
+    """
+    Wrapper for user-defined function registration.
+    """
+
+    def __init__(self, sparkSession: "SparkSession"):
+        self.sparkSession = sparkSession
+
+    def register(
+        self,
+        name: str,
+        f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
+        returnType: Optional["DataTypeOrString"] = None,
+    ) -> "UserDefinedFunctionLike":
+        # This is to check whether the input function is from a user-defined function or
+        # Python function.
+        if hasattr(f, "asNondeterministic"):
+            if returnType is not None:
+                raise TypeError(
+                    "Invalid return type: data type can not be specified when f is "
+                    "a user-defined function, but got %s." % returnType
+                )
+            f = cast("UserDefinedFunctionLike", f)
+            if f.evalType not in [
+                PythonEvalType.SQL_BATCHED_UDF,
+                PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+                PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+            ]:
+                raise ValueError(
+                    "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
+                    "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
+                )
+            return_udf = f
+            self.sparkSession._client.register_udf(
+                f, f.returnType, name, f.evalType, f.deterministic
+            )
+        else:
+            if returnType is None:
+                returnType = StringType()
+            return_udf = _create_udf(
+                f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
+            )
+
+            self.sparkSession._client.register_udf(f, returnType, name)
+
+        return return_udf
+
+    register.__doc__ = PySparkUDFRegistration.register.__doc__
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 942e1da95c8ac..38c93b2d0acc0 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -709,15 +709,15 @@ def udf(self) -> "UDFRegistration":

         .. versionadded:: 2.0.0

+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
         Returns
         -------
         :class:`UDFRegistration`

         Examples
         --------
-        >>> spark.udf
-        <pyspark.sql.udf.UDFRegistration object ...>
-
         Register a Python UDF, and use it in SQL.

         >>> strlen = spark.udf.register("strlen", lambda x: len(x))
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index eebfaaa39d841..b8e2c7b151a04 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2734,7 +2734,6 @@ def test_unsupported_session_functions(self):
             "sparkContext",
             "streams",
             "readStream",
-            "udf",
             "version",
         ):
             with self.assertRaises(NotImplementedError):
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 8d4bb69bf1633..5fe1dee7fe85d 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -19,7 +19,7 @@
 from pyspark.testing.connectutils import should_test_connect

-if should_test_connect:  # test_udf_with_partial_function
+if should_test_connect:
     from pyspark import sql
     from pyspark.sql.connect.udf import UserDefinedFunction

@@ -27,6 +27,7 @@
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.types import IntegerType


 class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
@@ -149,11 +150,6 @@ def test_non_existed_udaf(self):
     def test_non_existed_udf(self):
         super().test_non_existed_udf()

-    # TODO(SPARK-42210): implement `spark.udf`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_registration_returns_udf(self):
-        super().test_udf_registration_returns_udf()
-
     # TODO(SPARK-42210): implement `spark.udf`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_register_java_function(self):
@@ -179,6 +175,15 @@ def test_udf_with_string_return_type(self):
     def test_udf_in_subquery(self):
         super().test_udf_in_subquery()

+    def test_udf_registration_returns_udf(self):
+        df = self.spark.range(10)
+        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+
+        self.assertListEqual(
+            df.selectExpr("add_three(id) AS plus_three").collect(),
+            df.select(add_three("id").alias("plus_three")).collect(),
+        )
+

 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 79ae456b1f75c..9f8e3e469775b 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -477,6 +477,9 @@ def register(

         .. versionadded:: 1.3.1

+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
         Parameters
         ----------
         name : str,
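End to end, the behavior this patch enables mirrors the new parity test above: once a Spark Connect session is available, `spark.udf.register` ships the pickled function to the server and the returned handle is usable both from SQL and from DataFrame expressions. A usage sketch follows; the connection URL is an assumption, any reachable Spark Connect endpoint works:

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

# Assumes a Spark Connect server is listening on the default port.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

add_three = spark.udf.register("add_three", lambda x: x + 3, IntegerType())
spark.range(5).selectExpr("add_three(id) AS plus_three").show()
```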