diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index 5f59ada38b658..a521eab20d842 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -48,6 +48,11 @@ message Request {
   // The logical plan to be executed / analyzed.
   Plan plan = 3;

+  // Provides optional information about the client sending the request. This field
+  // can be used for language- or version-specific information. It is intended only
+  // for logging purposes and will not be interpreted by the server.
+  optional string client_type = 4;
+
   // User Context is used to refer to one particular user session that is executing
   // queries in the backend.
   message UserContext {
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 629497201344d..3c3203a8f514d 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -301,9 +301,7 @@ def register_udf(
         fun.parts.append(name)
         fun.serialized_function = cloudpickle.dumps((function, return_type))

-        req = pb2.Request()
-        if self._user_id is not None:
-            req.user_context.user_id = self._user_id
+        req = self._request_with_metadata()
         req.plan.command.create_function.CopyFrom(fun)

         self._execute_and_fetch(req)
@@ -357,9 +355,7 @@ def range(
         )

     def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
-        req = pb2.Request()
-        if self._user_id is not None:
-            req.user_context.user_id = self._user_id
+        req = self._request_with_metadata()
         req.plan.CopyFrom(plan)

         return self._execute_and_fetch(req)
@@ -407,12 +403,16 @@ def execute_command(self, command: pb2.Command) -> None:
         req.plan.command.CopyFrom(command)
         self._execute_and_fetch(req)

-    def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
+    def _request_with_metadata(self) -> pb2.Request:
         req = pb2.Request()
+        req.client_type = "_SPARK_CONNECT_PYTHON"
         if self._user_id:
             req.user_context.user_id = self._user_id
-        req.plan.CopyFrom(plan)
+        return req

+    def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
+        req = self._request_with_metadata()
+        req.plan.CopyFrom(plan)
         resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
         return AnalyzeResult.fromProto(resp)

diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 1f577089d1a29..0527e9b49aa86 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -35,7 +35,7 @@


 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xe0\x06\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xc8\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensionsB\x0e\n\x0c_client_type"\xe0\x06\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\narrowBatch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a=\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x86\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x12%\n\x0e\x65xplain_string\x18\x03 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
 )

 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -49,25 +49,25 @@
     _PLAN._serialized_start = 158
     _PLAN._serialized_end = 274
     _REQUEST._serialized_start = 277
-    _REQUEST._serialized_end = 551
-    _REQUEST_USERCONTEXT._serialized_start = 429
-    _REQUEST_USERCONTEXT._serialized_end = 551
-    _RESPONSE._serialized_start = 554
-    _RESPONSE._serialized_end = 1418
-    _RESPONSE_ARROWBATCH._serialized_start = 793
-    _RESPONSE_ARROWBATCH._serialized_end = 854
-    _RESPONSE_JSONBATCH._serialized_start = 856
-    _RESPONSE_JSONBATCH._serialized_end = 916
-    _RESPONSE_METRICS._serialized_start = 919
-    _RESPONSE_METRICS._serialized_end = 1403
-    _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1003
-    _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1313
-    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1201
-    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1313
-    _RESPONSE_METRICS_METRICVALUE._serialized_start = 1315
-    _RESPONSE_METRICS_METRICVALUE._serialized_end = 1403
-    _ANALYZERESPONSE._serialized_start = 1421
-    _ANALYZERESPONSE._serialized_end = 1555
-    _SPARKCONNECTSERVICE._serialized_start = 1558
-    _SPARKCONNECTSERVICE._serialized_end = 1720
+    _REQUEST._serialized_end = 605
+    _REQUEST_USERCONTEXT._serialized_start = 467
+    _REQUEST_USERCONTEXT._serialized_end = 589
+    _RESPONSE._serialized_start = 608
+    _RESPONSE._serialized_end = 1472
+    _RESPONSE_ARROWBATCH._serialized_start = 847
+    _RESPONSE_ARROWBATCH._serialized_end = 908
+    _RESPONSE_JSONBATCH._serialized_start = 910
+    _RESPONSE_JSONBATCH._serialized_end = 970
+    _RESPONSE_METRICS._serialized_start = 973
+    _RESPONSE_METRICS._serialized_end = 1457
+    _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1057
+    _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1367
+    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1255
+    _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1367
+    _RESPONSE_METRICS_METRICVALUE._serialized_start = 1369
+    _RESPONSE_METRICS_METRICVALUE._serialized_end = 1457
+    _ANALYZERESPONSE._serialized_start = 1475
+    _ANALYZERESPONSE._serialized_end = 1609
+    _SPARKCONNECTSERVICE._serialized_start = 1612
+    _SPARKCONNECTSERVICE._serialized_end = 1774
 # @@protoc_insertion_point(module_scope)
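Note on the regenerated code above: client_type is declared as a proto3 "optional" field, so the generated classes track explicit field presence through a synthetic _client_type oneof (visible as the B\x0e\n\x0c_client_type bytes and the \x88\x01\x01 presence marker in the new descriptor). A minimal sketch of the resulting behavior on the Python side; this snippet is illustrative only and not part of the patch:

    from pyspark.sql.connect.proto import base_pb2 as pb2

    req = pb2.Request()
    assert not req.HasField("client_type")     # unset: presence is tracked explicitly
    req.client_type = "_SPARK_CONNECT_PYTHON"  # the value _request_with_metadata() stamps
    assert req.HasField("client_type")
    req.ClearField("client_type")              # back to unset, not merely ""

Because presence is explicit, the server can distinguish "no client_type supplied" from an empty string when logging.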
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index bf6d080d9fd97..e70f9db14a368 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -135,6 +135,7 @@ class Request(google.protobuf.message.Message):
     CLIENT_ID_FIELD_NUMBER: builtins.int
     USER_CONTEXT_FIELD_NUMBER: builtins.int
     PLAN_FIELD_NUMBER: builtins.int
+    CLIENT_TYPE_FIELD_NUMBER: builtins.int
     client_id: builtins.str
     """The client_id is set by the client to be able to collate streaming responses
     from different queries.
@@ -145,23 +146,50 @@ class Request(google.protobuf.message.Message):
     @property
     def plan(self) -> global___Plan:
         """The logical plan to be executed / analyzed."""
+    client_type: builtins.str
+    """Provides optional information about the client sending the request. This field
+    can be used for language- or version-specific information. It is intended only
+    for logging purposes and will not be interpreted by the server.
+    """
     def __init__(
         self,
         *,
         client_id: builtins.str = ...,
         user_context: global___Request.UserContext | None = ...,
         plan: global___Plan | None = ...,
+        client_type: builtins.str | None = ...,
     ) -> None: ...
     def HasField(
         self,
-        field_name: typing_extensions.Literal["plan", b"plan", "user_context", b"user_context"],
+        field_name: typing_extensions.Literal[
+            "_client_type",
+            b"_client_type",
+            "client_type",
+            b"client_type",
+            "plan",
+            b"plan",
+            "user_context",
+            b"user_context",
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "client_id", b"client_id", "plan", b"plan", "user_context", b"user_context"
+            "_client_type",
+            b"_client_type",
+            "client_id",
+            b"client_id",
+            "client_type",
+            b"client_type",
+            "plan",
+            b"plan",
+            "user_context",
+            b"user_context",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"]
+    ) -> typing_extensions.Literal["client_type"] | None: ...

 global___Request = Request
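As the stub additions show, presence of the optional field can also be queried through the synthetic oneof group exposed by WhichOneof. A short illustrative check (hypothetical test code, not part of the patch):

    from pyspark.sql.connect.proto import base_pb2 as pb2

    req = pb2.Request()
    assert req.WhichOneof("_client_type") is None           # field not set yet
    req.client_type = "_SPARK_CONNECT_PYTHON"
    assert req.WhichOneof("_client_type") == "client_type"  # field now present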